20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallPtrSet.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
29using namespace firrtl;
33static Value
dropWrite(PatternRewriter &rewriter, OpResult old,
35 SmallPtrSet<Operation *, 8> users;
36 for (
auto *user : old.getUsers())
38 for (Operation *user : users)
39 if (
auto connect = dyn_cast<FConnectLike>(user))
40 if (connect.getDest() == old)
41 rewriter.eraseOp(user);
51 if (op->getNumRegions() != 0)
53 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
61 Operation *op = passthrough.getDefiningOp();
64 assert(op &&
"passthrough must be an operation");
65 Operation *oldOp = old.getOwner();
66 auto name = oldOp->getAttrOfType<StringAttr>(
"name");
68 op->setAttr(
"name", name);
76#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
84 auto resultType = type_cast<IntType>(op->getResult(0).getType());
85 if (!resultType.hasWidth())
87 for (Value operand : op->getOperands())
88 if (!type_cast<IntType>(operand.getType()).hasWidth())
95 auto t = type_dyn_cast<UIntType>(type);
96 if (!t || !t.hasWidth() || t.getWidth() != 1)
103static void updateName(PatternRewriter &rewriter, Operation *op,
108 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) &&
"Should never rename");
109 auto newName = name.getValue();
110 auto newOpName = op->getAttrOfType<StringAttr>(
"name");
113 newName =
chooseName(newOpName.getValue(), name.getValue());
115 if (!newOpName || newOpName.getValue() != newName)
116 rewriter.modifyOpInPlace(
117 op, [&] { op->setAttr(
"name", rewriter.getStringAttr(newName)); });
125 if (
auto *newOp = newValue.getDefiningOp()) {
126 auto name = op->getAttrOfType<StringAttr>(
"name");
129 rewriter.replaceOp(op, newValue);
135template <
typename OpTy,
typename... Args>
137 Operation *op, Args &&...args) {
138 auto name = op->getAttrOfType<StringAttr>(
"name");
140 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
148 if (
auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
149 return namableOp.hasDroppableName();
160static std::optional<APSInt>
162 assert(type_cast<IntType>(operand.getType()) &&
163 "getExtendedConstant is limited to integer types");
170 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
175 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
176 return APSInt(destWidth,
177 type_cast<IntType>(operand.getType()).isUnsigned());
185 if (
auto attr = dyn_cast<BoolAttr>(operand))
186 return APSInt(APInt(1, attr.getValue()));
187 if (
auto attr = dyn_cast<IntegerAttr>(operand))
188 return attr.getAPSInt();
196 return cst->isZero();
223 Operation *op, ArrayRef<Attribute> operands,
BinOpKind opKind,
224 const function_ref<APInt(
const APSInt &,
const APSInt &)> &calculate) {
225 assert(operands.size() == 2 &&
"binary op takes two operands");
228 auto resultType = type_cast<IntType>(op->getResult(0).getType());
229 if (resultType.getWidthOrSentinel() < 0)
233 if (resultType.getWidthOrSentinel() == 0)
234 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
240 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
242 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
243 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
244 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
245 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
246 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
250 int32_t operandWidth;
253 operandWidth = resultType.getWidthOrSentinel();
258 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
262 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
273 APInt resultValue = calculate(*lhs, *rhs);
278 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
280 assert((
unsigned)resultType.getWidthOrSentinel() ==
281 resultValue.getBitWidth());
294 Operation *op, PatternRewriter &rewriter,
295 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
297 if (op->getNumResults() != 1)
299 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
304 auto width = type.getBitWidthOrSentinel();
309 SmallVector<Attribute, 3> constOperands;
310 constOperands.reserve(op->getNumOperands());
311 for (
auto operand : op->getOperands()) {
313 if (
auto *defOp = operand.getDefiningOp())
314 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
315 [&](
auto op) { attr = op.getValueAttr(); });
316 constOperands.push_back(attr);
321 auto result = canonicalize(constOperands);
325 if (
auto cst = dyn_cast<Attribute>(result))
326 resultValue = op->getDialect()
327 ->materializeConstant(rewriter, cst, type, op->getLoc())
330 resultValue = cast<Value>(result);
334 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
335 resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);
338 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
339 resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
340 else if (type_isa<UIntType>(type) &&
341 type_isa<SIntType>(resultValue.getType()))
342 resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);
344 assert(type == resultValue.getType() &&
"canonicalization changed type");
352 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
358 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
364 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
/// Folds a ConstantOp to its own value attribute. The assert records the
/// expectation that the fold adaptor carries no operands for a constant.
371OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
372 assert(adaptor.getOperands().empty() &&
"constant has no operands");
373 return getValueAttr();
/// Folds a SpecialConstantOp to its own value attribute. The assert records
/// the expectation that the fold adaptor carries no operands.
376OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
377 assert(adaptor.getOperands().empty() &&
"constant has no operands");
378 return getValueAttr();
/// Folds an AggregateConstantOp to its fields attribute (note: getFieldsAttr,
/// not getValueAttr as in the other constant folders here). The assert records
/// the expectation that the fold adaptor carries no operands.
381OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
382 assert(adaptor.getOperands().empty() &&
"constant has no operands");
383 return getFieldsAttr();
/// Folds a StringConstantOp to its own value attribute. The assert records the
/// expectation that the fold adaptor carries no operands.
386OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
387 assert(adaptor.getOperands().empty() &&
"constant has no operands");
388 return getValueAttr();
/// Folds an FIntegerConstantOp to its own value attribute. The assert records
/// the expectation that the fold adaptor carries no operands.
391OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
392 assert(adaptor.getOperands().empty() &&
"constant has no operands");
393 return getValueAttr();
/// Folds a BoolConstantOp to its own value attribute. The assert records the
/// expectation that the fold adaptor carries no operands.
396OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
397 assert(adaptor.getOperands().empty() &&
"constant has no operands");
398 return getValueAttr();
/// Folds a DoubleConstantOp to its own value attribute. The assert records the
/// expectation that the fold adaptor carries no operands.
401OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
402 assert(adaptor.getOperands().empty() &&
"constant has no operands");
403 return getValueAttr();
410OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
413 [=](
const APSInt &a,
const APSInt &b) { return a + b; });
/// Adds AddPrimOp's canonicalization patterns to `results`: moveConstAdd,
/// AddOfZero, AddOfSelf and AddOfPad from the `patterns` namespace, forwarding
/// `context` to `results.insert`.
416void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
418 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
419 patterns::AddOfSelf, patterns::AddOfPad>(
context);
422OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
425 [=](
const APSInt &a,
const APSInt &b) { return a - b; });
/// Adds SubPrimOp's canonicalization patterns to `results`: SubOfZero,
/// SubFromZeroSigned, SubFromZeroUnsigned, SubOfSelf, SubOfPadL and SubOfPadR
/// from the `patterns` namespace, forwarding `context` to `results.insert`.
428void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
430 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
431 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
432 patterns::SubOfPadL, patterns::SubOfPadR>(
context);
435OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
447 [=](
const APSInt &a,
const APSInt &b) { return a * b; });
450OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
457 if (getLhs() == getRhs()) {
458 auto width = getType().base().getWidthOrSentinel();
463 return getIntAttr(getType(), APInt(width, 1));
480 if (
auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
481 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
486 [=](
const APSInt &a,
const APSInt &b) -> APInt {
489 return APInt(a.getBitWidth(), 0);
493OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
500 if (getLhs() == getRhs())
514 [=](
const APSInt &a,
const APSInt &b) -> APInt {
517 return APInt(a.getBitWidth(), 0);
521OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
524 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
527OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
530 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
533OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
536 [=](
const APSInt &a,
const APSInt &b) -> APInt {
537 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
543OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
546 if (rhsCst->isZero())
550 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
551 getRhs().getType() == getType())
557 if (lhsCst->isZero())
561 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
562 getRhs().getType() == getType())
567 if (getLhs() == getRhs() && getRhs().getType() == getType())
572 [](
const APSInt &a,
const APSInt &b) -> APInt { return a & b; });
575void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
578 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
579 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
580 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(
context);
583OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
586 if (rhsCst->isZero() && getLhs().getType() == getType())
590 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
591 getLhs().getType() == getType())
597 if (lhsCst->isZero() && getRhs().getType() == getType())
601 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
602 getRhs().getType() == getType())
607 if (getLhs() == getRhs() && getRhs().getType() == getType())
612 [](
const APSInt &a,
const APSInt &b) -> APInt { return a | b; });
615void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
617 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
618 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
622OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
625 if (rhsCst->isZero() &&
631 if (lhsCst->isZero() &&
636 if (getLhs() == getRhs())
639 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
643 [](
const APSInt &a,
const APSInt &b) -> APInt { return a ^ b; });
646void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
648 results.insert<patterns::extendXor, patterns::moveConstXor,
649 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
/// Adds LEQPrimOp's single canonicalization pattern, patterns::LEQWithConstLHS,
/// to `results`, forwarding `context` to `results.insert`.
653void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
655 results.insert<patterns::LEQWithConstLHS>(
context);
658OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
659 bool isUnsigned = getLhs().getType().base().isUnsigned();
662 if (getLhs() == getRhs())
666 if (
auto width = getLhs().getType().base().
getWidth()) {
668 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
669 commonWidth = std::max(commonWidth, 1);
680 if (isUnsigned && rhsCst->zext(commonWidth)
693 [=](
const APSInt &a,
const APSInt &b) -> APInt {
694 return APInt(1, a <= b);
/// Adds LTPrimOp's single canonicalization pattern, patterns::LTWithConstLHS,
/// to `results`, forwarding `context` to `results.insert`.
698void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
700 results.insert<patterns::LTWithConstLHS>(
context);
703OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
704 IntType lhsType = getLhs().getType();
708 if (getLhs() == getRhs())
718 if (
auto width = lhsType.
getWidth()) {
720 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
721 commonWidth = std::max(commonWidth, 1);
732 if (isUnsigned && rhsCst->zext(commonWidth)
745 [=](
const APSInt &a,
const APSInt &b) -> APInt {
746 return APInt(1, a < b);
/// Adds GEQPrimOp's single canonicalization pattern, patterns::GEQWithConstLHS,
/// to `results`, forwarding `context` to `results.insert`.
750void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
752 results.insert<patterns::GEQWithConstLHS>(
context);
755OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
756 IntType lhsType = getLhs().getType();
760 if (getLhs() == getRhs())
765 if (rhsCst->isZero() && isUnsigned)
770 if (
auto width = lhsType.
getWidth()) {
772 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
773 commonWidth = std::max(commonWidth, 1);
776 if (isUnsigned && rhsCst->zext(commonWidth)
797 [=](
const APSInt &a,
const APSInt &b) -> APInt {
798 return APInt(1, a >= b);
/// Adds GTPrimOp's single canonicalization pattern, patterns::GTWithConstLHS,
/// to `results`, forwarding `context` to `results.insert`.
802void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
804 results.insert<patterns::GTWithConstLHS>(
context);
807OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
808 IntType lhsType = getLhs().getType();
812 if (getLhs() == getRhs())
816 if (
auto width = lhsType.
getWidth()) {
818 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
819 commonWidth = std::max(commonWidth, 1);
822 if (isUnsigned && rhsCst->zext(commonWidth)
843 [=](
const APSInt &a,
const APSInt &b) -> APInt {
844 return APInt(1, a > b);
848OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
850 if (getLhs() == getRhs())
856 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
857 getRhs().getType() == getType())
863 [=](
const APSInt &a,
const APSInt &b) -> APInt {
864 return APInt(1, a == b);
868LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
870 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
872 auto width = op.getLhs().getType().getBitWidthOrSentinel();
875 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
876 op.getRhs().getType() == op.getType()) {
877 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
882 if (rhsCst->isZero() && width > 1) {
883 auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
884 return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
888 if (rhsCst->isAllOnes() && width > 1 &&
889 op.getLhs().getType() == op.getRhs().getType()) {
890 return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
898OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
900 if (getLhs() == getRhs())
906 if (rhsCst->isZero() && getLhs().getType() == getType() &&
907 getRhs().getType() == getType())
913 [=](
const APSInt &a,
const APSInt &b) -> APInt {
914 return APInt(1, a != b);
918LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
920 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
922 auto width = op.getLhs().getType().getBitWidthOrSentinel();
925 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
926 op.getRhs().getType() == op.getType()) {
927 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
932 if (rhsCst->isZero() && width > 1) {
933 return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
938 if (rhsCst->isAllOnes() && width > 1 &&
939 op.getLhs().getType() == op.getRhs().getType()) {
941 AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
942 return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
950OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
956OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
962OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
966 return IntegerAttr::get(IntegerType::get(getContext(),
967 lhsCst->getBitWidth(),
968 IntegerType::Signed),
969 lhsCst->ashr(*rhsCst));
972 if (rhsCst->isZero())
979OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
984 return IntegerAttr::get(IntegerType::get(getContext(),
985 lhsCst->getBitWidth(),
986 IntegerType::Signed),
987 lhsCst->shl(*rhsCst));
990 if (rhsCst->isZero())
1001OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
1002 auto base = getInput().getType();
1009OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
1016OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1024 if (getType().base().hasWidth())
/// Adds AsSIntPrimOp's single canonicalization pattern, patterns::StoUtoS, to
/// `results`, forwarding `context` to `results.insert`.
1031void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1033 results.insert<patterns::StoUtoS>(
context);
1036OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1044 if (getType().base().hasWidth())
/// Adds AsUIntPrimOp's single canonicalization pattern, patterns::UtoStoU, to
/// `results`, forwarding `context` to `results.insert`.
1051void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1053 results.insert<patterns::UtoStoU>(
context);
1056OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1058 if (getInput().getType() == getType())
1063 return BoolAttr::get(getContext(), cst->getBoolValue());
1068OpFoldResult AsResetPrimOp::fold(FoldAdaptor adaptor) {
1070 return BoolAttr::get(getContext(), cst->getBoolValue());
1074OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1076 if (getInput().getType() == getType())
1081 return BoolAttr::get(getContext(), cst->getBoolValue());
1086OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1092 getType().base().getWidthOrSentinel()))
/// Adds CvtPrimOp's canonicalization patterns to `results`: patterns::CVTSigned
/// and patterns::CVTUnSigned, forwarding `context` to `results.insert`.
1098void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1100 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(
context);
1103OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1110 getType().base().getWidthOrSentinel()))
1111 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1116OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1121 getType().base().getWidthOrSentinel()))
1127void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1129 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1130 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1137 : RewritePattern(opName, 0,
context) {}
1143 ConstantOp constantOp,
1144 SmallVectorImpl<Value> &remaining)
const = 0;
1151 mlir::PatternRewriter &rewriter)
const override {
1153 auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
1157 SmallVector<Value> nonConstantOperands;
1160 for (
auto operand : catOp.getInputs()) {
1161 if (
auto constantOp = operand.getDefiningOp<ConstantOp>()) {
1163 if (
handleConstant(rewriter, op, constantOp, nonConstantOperands))
1167 nonConstantOperands.push_back(operand);
1172 if (nonConstantOperands.empty()) {
1173 replaceOpWithNewOpAndCopyName<ConstantOp>(
1174 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1180 if (nonConstantOperands.size() == 1) {
1181 rewriter.modifyOpInPlace(
1182 op, [&] { op->setOperand(0, nonConstantOperands.front()); });
1187 if (catOp->hasOneUse() &&
1188 nonConstantOperands.size() < catOp->getNumOperands()) {
1189 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
1190 nonConstantOperands);
1203 SmallVectorImpl<Value> &remaining)
const override {
1204 if (value.getValue().isZero())
1207 replaceOpWithNewOpAndCopyName<ConstantOp>(
1208 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1221 SmallVectorImpl<Value> &remaining)
const override {
1222 if (value.getValue().isAllOnes())
1225 replaceOpWithNewOpAndCopyName<ConstantOp>(
1226 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1239 SmallVectorImpl<Value> &remaining)
const override {
1240 if (value.getValue().isZero())
1242 remaining.push_back(value);
1248OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1252 if (getInput().getType().getBitWidthOrSentinel() == 0)
1257 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1261 if (
isUInt1(getInput().getType()))
1267void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1269 results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1270 patterns::AndRPadS, patterns::AndRCatAndR_left,
1274OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1278 if (getInput().getType().getBitWidthOrSentinel() == 0)
1283 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1287 if (
isUInt1(getInput().getType()))
1293void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1295 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1296 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right,
OrRCat>(
1300OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1304 if (getInput().getType().getBitWidthOrSentinel() == 0)
1309 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1312 if (
isUInt1(getInput().getType()))
1318void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1321 .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1322 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right,
XorRCat>(
1330OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1331 auto inputs = getInputs();
1332 auto inputAdaptors = adaptor.getInputs();
1339 if (inputs.size() == 1 && inputs[0].getType() == getType())
1347 SmallVector<Value> nonZeroInputs;
1348 SmallVector<Attribute> nonZeroAttributes;
1349 bool allConstant =
true;
1350 for (
auto [input, attr] :
llvm::zip(inputs, inputAdaptors)) {
1351 auto inputType = type_cast<IntType>(input.getType());
1352 if (inputType.getBitWidthOrSentinel() != 0) {
1353 nonZeroInputs.push_back(input);
1355 allConstant =
false;
1356 if (nonZeroInputs.size() > 1 && !allConstant)
1362 if (nonZeroInputs.empty())
1366 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1367 return nonZeroInputs[0];
1373 SmallVector<APInt> constants;
1374 for (
auto inputAdaptor : inputAdaptors) {
1376 constants.push_back(*cst);
1381 assert(!constants.empty());
1383 APInt result = constants[0];
1384 for (
size_t i = 1; i < constants.size(); ++i)
1385 result = result.concat(constants[i]);
/// Adds DShlPrimOp's single canonicalization pattern, patterns::DShlOfConstant,
/// to `results`, forwarding `context` to `results.insert`.
1390void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1392 results.insert<patterns::DShlOfConstant>(
context);
/// Adds DShrPrimOp's single canonicalization pattern, patterns::DShrOfConstant,
/// to `results`, forwarding `context` to `results.insert`.
1395void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1397 results.insert<patterns::DShrOfConstant>(
context);
1405 using OpRewritePattern::OpRewritePattern;
1408 matchAndRewrite(CatPrimOp cat,
1409 mlir::PatternRewriter &rewriter)
const override {
1411 cat.getType().getBitWidthOrSentinel() == 0)
1415 if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
1419 SmallVector<Value> operands;
1420 SmallVector<Value> worklist;
1421 auto pushOperands = [&worklist](CatPrimOp op) {
1422 for (
auto operand :
llvm::reverse(op.getInputs()))
1423 worklist.push_back(operand);
1426 bool hasSigned =
false, hasUnsigned =
false;
1427 while (!worklist.empty()) {
1428 auto value = worklist.pop_back_val();
1429 auto catOp = value.getDefiningOp<CatPrimOp>();
1431 operands.push_back(value);
1432 (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) =
true;
1436 pushOperands(catOp);
1441 auto castToUIntIfSigned = [&](Value value) -> Value {
1442 if (type_isa<UIntType>(value.getType()))
1444 return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
1447 assert(operands.size() >= 1 &&
"zero width cast must be rejected");
1449 if (operands.size() == 1) {
1450 rewriter.replaceOp(cat, castToUIntIfSigned(operands[0]));
1454 if (operands.size() == cat->getNumOperands())
1458 if (hasSigned && hasUnsigned)
1459 for (
auto &operand : operands)
1460 operand = castToUIntIfSigned(operand);
1462 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
1471 using OpRewritePattern::OpRewritePattern;
1474 matchAndRewrite(CatPrimOp cat,
1475 mlir::PatternRewriter &rewriter)
const override {
1479 SmallVector<Value> operands;
1481 for (
size_t i = 0; i < cat->getNumOperands(); ++i) {
1482 auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
1484 operands.push_back(cat.getInputs()[i]);
1487 APSInt value = cst.getValue();
1489 for (; j < cat->getNumOperands(); ++j) {
1490 auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
1493 value = value.concat(nextCst.getValue());
1498 operands.push_back(cst);
1501 operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
1507 if (operands.size() == cat->getNumOperands())
1510 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
/// Adds CatPrimOp's canonicalization patterns to `results`. Three come from the
/// `patterns` namespace (CatBitsBits, CatDoubleConst, CatCast); FlattenCat and
/// CatOfConstant are named without that qualifier. `context` is forwarded to
/// `results.insert`.
1519void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1521 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1522 patterns::CatCast, FlattenCat, CatOfConstant>(
context);
1529OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
1531 if (getInputs().size() == 1)
1532 return getInputs()[0];
1535 if (!llvm::all_of(adaptor.getInputs(), [](Attribute operand) {
1536 return isa_and_nonnull<StringAttr>(operand);
1541 SmallString<64> result;
1542 for (
auto operand : adaptor.getInputs())
1543 result += cast<StringAttr>(operand).getValue();
1545 return StringAttr::get(getContext(), result);
1553 using OpRewritePattern::OpRewritePattern;
1556 matchAndRewrite(StringConcatOp concat,
1557 mlir::PatternRewriter &rewriter)
const override {
1561 bool hasNestedConcat = llvm::any_of(concat.getInputs(), [](Value operand) {
1562 auto nestedConcat = operand.getDefiningOp<StringConcatOp>();
1563 return nestedConcat && operand.hasOneUse();
1566 if (!hasNestedConcat)
1570 SmallVector<Value> flatOperands;
1571 for (
auto input : concat.getInputs()) {
1572 if (
auto nestedConcat = input.getDefiningOp<StringConcatOp>();
1573 nestedConcat && input.hasOneUse())
1574 llvm::append_range(flatOperands, nestedConcat.getInputs());
1576 flatOperands.push_back(input);
1579 rewriter.modifyOpInPlace(concat,
1580 [&]() { concat->setOperands(flatOperands); });
1587class MergeAdjacentStringConstants
1590 using OpRewritePattern::OpRewritePattern;
1593 matchAndRewrite(StringConcatOp concat,
1594 mlir::PatternRewriter &rewriter)
const override {
1596 SmallVector<Value> newOperands;
1597 SmallString<64> accumulatedLit;
1598 SmallVector<StringConstantOp> accumulatedOps;
1599 bool changed =
false;
1601 auto flushLiterals = [&]() {
1602 if (accumulatedOps.empty())
1606 if (accumulatedOps.size() == 1) {
1607 newOperands.push_back(accumulatedOps[0]);
1610 auto newLit = rewriter.createOrFold<StringConstantOp>(
1611 concat.getLoc(), StringAttr::get(getContext(), accumulatedLit));
1612 newOperands.push_back(newLit);
1615 accumulatedLit.clear();
1616 accumulatedOps.clear();
1619 for (
auto operand : concat.getInputs()) {
1620 if (
auto litOp = operand.getDefiningOp<StringConstantOp>()) {
1622 if (litOp.getValue().empty()) {
1626 accumulatedLit += litOp.getValue();
1627 accumulatedOps.push_back(litOp);
1630 newOperands.push_back(operand);
1641 if (newOperands.empty())
1642 return rewriter.replaceOpWithNewOp<StringConstantOp>(
1643 concat, StringAttr::get(getContext(),
"")),
1647 rewriter.modifyOpInPlace(concat,
1648 [&]() { concat->setOperands(newOperands); });
/// Adds StringConcatOp's canonicalization patterns to `results`:
/// FlattenStringConcat and MergeAdjacentStringConstants, forwarding `context`
/// to `results.insert`.
1655void StringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
1657 results.insert<FlattenStringConcat, MergeAdjacentStringConstants>(
context);
1664OpFoldResult PropEqOp::fold(FoldAdaptor adaptor) {
1665 auto lhsAttr = adaptor.getLhs();
1666 auto rhsAttr = adaptor.getRhs();
1667 if (!lhsAttr || !rhsAttr)
1670 return BoolAttr::get(getContext(), lhsAttr == rhsAttr);
1679 if (
auto boolAttr = dyn_cast_or_null<BoolAttr>(attr))
1680 return boolAttr.getValue();
1681 return std::nullopt;
1684OpFoldResult BoolAndOp::fold(FoldAdaptor adaptor) {
1688 return BoolAttr::get(getContext(), *lhs && *rhs);
1690 if ((lhs && !*lhs) || (rhs && !*rhs))
1691 return BoolAttr::get(getContext(),
false);
1700OpFoldResult BoolOrOp::fold(FoldAdaptor adaptor) {
1704 return BoolAttr::get(getContext(), *lhs || *rhs);
1706 if ((lhs && *lhs) || (rhs && *rhs))
1707 return BoolAttr::get(getContext(),
true);
1716OpFoldResult BoolXorOp::fold(FoldAdaptor adaptor) {
1720 return BoolAttr::get(getContext(), *lhs ^ *rhs);
1729OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1732 if (op.getType() == op.getInput().getType())
1733 return op.getInput();
1737 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1738 if (op.getType() == in.getInput().getType())
1739 return in.getInput();
1744OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1745 IntType inputType = getInput().getType();
1746 IntType resultType = getType();
1748 if (inputType == getType() && resultType.
hasWidth())
1755 cst->extractBits(getHi() - getLo() + 1, getLo()));
1761 using OpRewritePattern::OpRewritePattern;
1765 mlir::PatternRewriter &rewriter)
const override {
1766 auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
1769 int32_t bitPos = bits.getLo();
1770 auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
1771 if (resultWidth < 0)
1773 for (
auto operand : llvm::reverse(cat.getInputs())) {
1775 type_cast<IntType>(operand.getType()).getWidthOrSentinel();
1776 if (operandWidth < 0)
1778 if (bitPos < operandWidth) {
1779 if (bitPos + resultWidth <= operandWidth) {
1780 auto newBits = rewriter.createOrFold<BitsPrimOp>(
1781 bits.getLoc(), operand, bitPos + resultWidth - 1, bitPos);
1787 bitPos -= operandWidth;
1793void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1796 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1804 unsigned loBit, PatternRewriter &rewriter) {
1805 auto resType = type_cast<IntType>(op->getResult(0).getType());
1806 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1807 value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);
1809 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1810 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1811 }
else if (resType.isUnsigned() &&
1812 !type_cast<IntType>(value.getType()).isUnsigned()) {
1813 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1815 rewriter.replaceOp(op, value);
1818template <
typename OpTy>
1819static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
1821 if (op.getType().getBitWidthOrSentinel() == 0)
1823 APInt(0, 0, op.getType().isSignedInteger()));
1826 if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
1827 return op.getHigh();
1832 if (op.getType().getBitWidthOrSentinel() < 0)
1837 if (cond->isZero() && op.getLow().getType() == op.getType())
1839 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1840 return op.getHigh();
1844 if (
auto lowCst =
getConstant(adaptor.getLow())) {
1846 if (
auto highCst =
getConstant(adaptor.getHigh())) {
1848 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1849 *highCst == *lowCst)
1852 if (
auto intType = type_dyn_cast<IntType>(op.getType()))
1853 if (intType.hasWidth() &&
1854 (
unsigned)intType.getWidthOrSentinel() == highCst->getBitWidth())
1857 if (highCst->isOne() && lowCst->isZero() &&
1858 op.getType() == op.getSel().getType())
/// Folds a MuxPrimOp by delegating to the shared foldMux helper, passing this
/// op and the fold adaptor; the helper's result is returned unchanged.
1871OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1872 return foldMux(*
this, adaptor);
/// Folds a Mux2CellIntrinsicOp by delegating to the shared foldMux helper,
/// passing this op and the fold adaptor; the helper's result is returned
/// unchanged.
1875OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1876 return foldMux(*
this, adaptor);
/// Mux4CellIntrinsicOp has no folder logic: it ignores `adaptor` and always
/// returns a default-constructed (empty) OpFoldResult.
1879OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return {}; }
1888 using OpRewritePattern::OpRewritePattern;
1891 matchAndRewrite(MuxPrimOp mux,
1892 mlir::PatternRewriter &rewriter)
const override {
1893 auto width = mux.getType().getBitWidthOrSentinel();
1897 auto pad = [&](Value input) -> Value {
1899 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1900 if (inputWidth < 0 || width == inputWidth)
1902 return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
1907 auto newHigh = pad(mux.getHigh());
1908 auto newLow = pad(mux.getLow());
1909 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1912 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1913 rewriter, mux, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1923 using OpRewritePattern::OpRewritePattern;
1925 static const int depthLimit = 5;
1927 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1928 mlir::PatternRewriter &rewriter,
1929 bool updateInPlace)
const {
1930 if (updateInPlace) {
1931 rewriter.modifyOpInPlace(mux, [&] {
1932 mux.setOperand(1, high);
1933 mux.setOperand(2, low);
1937 rewriter.setInsertionPointAfter(mux);
1938 return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
1939 ValueRange{mux.getSel(), high, low})
1944 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1945 bool updateInPlace,
int limit)
const {
1946 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1949 if (mux.getSel() == cond)
1950 return mux.getHigh();
1951 if (limit > depthLimit)
1953 updateInPlace &= mux->hasOneUse();
1955 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1957 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1960 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1961 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
// Counterpart of tryCondTrue for the case where `cond` is false: a mux
// selecting on `cond` contributes its low operand; otherwise the high and
// low operands are searched recursively and the mux is updated or cloned
// around the simplified operand.
// NOTE(review): the handling of a non-mux `op` and of `limit > depthLimit`
// is not visible in this view.
1966 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1967 bool updateInPlace,
int limit)
const {
1968 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1971 if (mux.getSel() == cond)
1972 return mux.getLow();
1973 if (limit > depthLimit)
// In-place mutation is only kept for muxes with a single use.
1975 updateInPlace &= mux->hasOneUse();
1977 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1979 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1981 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1983 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
// Entry point of the shared-condition mux pattern. The high operand is
// simplified assuming the selector is true and the low operand assuming it
// is false; a simplified value replaces operand 1 or 2 of `mux` in place.
// Both searches start with in-place updates allowed and a depth of 0.
1989 matchAndRewrite(MuxPrimOp mux,
1990 mlir::PatternRewriter &rewriter)
const override {
1991 auto width = mux.getType().getBitWidthOrSentinel();
1995 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter,
true, 0)) {
1996 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
2000 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter,
true, 0)) {
2001 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
// Registers the canonicalization patterns for MuxPrimOp: the two local
// patterns MuxPad and MuxSharedCond plus the listed `patterns::` entries.
2010void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
2013 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
2014 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
2015 patterns::MuxSameTrue, patterns::MuxSameFalse,
2016 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
// Registers the canonicalization patterns for Mux2CellIntrinsicOp; only
// patterns::Mux2PadSel is added.
2020void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
2021 RewritePatternSet &results, MLIRContext *
context) {
2022 results.add<patterns::Mux2PadSel>(
context);
// Registers the canonicalization patterns for Mux4CellIntrinsicOp; only
// patterns::Mux4PadSel is added.
2025void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
2026 RewritePatternSet &results, MLIRContext *
context) {
2027 results.add<patterns::Mux4PadSel>(
context);
// Fold hook for PadPrimOp. Visible cases: a check for the input type
// already matching the result type, and constant extension to the
// destination width.
// NOTE(review): the definition of `cst` and the bodies of the early checks
// are not visible in this view.
2030OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
2031 auto input = this->getInput();
2034 if (input.getType() == getType())
2038 auto inputType = input.getType().base();
2045 auto destWidth = getType().base().getWidthOrSentinel();
// -1 is the sentinel for an unknown destination width.
2046 if (destWidth == -1)
// Signed constants with a non-zero bit width are sign-extended; everything
// else is zero-extended.
2049 if (inputType.
isSigned() && cst->getBitWidth())
2050 return getIntAttr(getType(), cst->sext(destWidth));
2051 return getIntAttr(getType(), cst->zext(destWidth));
// Fold hook for ShlPrimOp. With a known input width, a constant input is
// zero-extended to `inputWidth + shiftAmount` bits and then shifted left.
// NOTE(review): the definitions of `inputWidth` and `cst`, and the body of
// the zero-shift check, are not visible in this view.
2057OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
2058 auto input = this->getInput();
2059 IntType inputType = input.getType();
2060 int shiftAmount = getAmount();
2063 if (shiftAmount == 0)
// -1 is the sentinel for an unknown input width.
2069 if (inputWidth != -1) {
2070 auto resultWidth = inputWidth + shiftAmount;
2071 shiftAmount = std::min(shiftAmount, resultWidth);
2072 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
// Fold hook for ShrPrimOp. Visible cases: an unsigned input shifted by at
// least its width folds to a zero-width zero constant; otherwise a constant
// input is shifted (arithmetic or logical) and truncated to the result
// width, which is never smaller than one bit.
// NOTE(review): the definitions of `inputWidth`, `cst` and `value`, the
// condition choosing between ashr and lshr, and the bodies of the early
// checks are not visible in this view.
2078OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
2079 auto input = this->getInput();
2080 IntType inputType = input.getType();
2081 int shiftAmount = getAmount();
2087 if (shiftAmount == 0 && inputWidth > 0)
2090 if (inputWidth == -1)
2092 if (inputWidth == 0)
2097 if (shiftAmount >= inputWidth && inputType.
isUnsigned())
2098 return getIntAttr(getType(), APInt(0, 0,
false));
// The arithmetic shift amount is clamped to `inputWidth - 1`.
2104 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
2106 value = cst->lshr(std::min(shiftAmount, inputWidth));
2107 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
2108 return getIntAttr(getType(), value.trunc(resultWidth));
// Canonicalizes ShrPrimOp by handing the op to replaceWithBits with a high
// index of `inputWidth - 1` and a low index of the shift amount.
// When the shift amount is at least the input width, the visible code
// clamps it to `inputWidth - 1`.
// NOTE(review): the unsigned special case inside that branch and the
// early-exit bodies are not visible in this view.
2113LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
2114 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2115 if (inputWidth <= 0)
2119 unsigned shiftAmount = op.getAmount();
2120 if (
int(shiftAmount) >= inputWidth) {
2122 if (op.getType().base().isUnsigned())
2128 shiftAmount = inputWidth - 1;
2131 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
// Canonicalizes HeadPrimOp by handing the op to replaceWithBits with a high
// index of `inputWidth - 1` and a low index of `inputWidth - keepAmount`.
// NOTE(review): the body of the `inputWidth <= 0` check is not visible.
2135LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
2136 PatternRewriter &rewriter) {
2137 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2138 if (inputWidth <= 0)
2142 unsigned keepAmount = op.getAmount();
2144 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
// Fold hook for HeadPrimOp. A constant input is shifted right by
// `input width - amount` and truncated to `amount` bits.
// NOTE(review): the definition of `cst` and the guard for an unknown input
// width are not visible in this view.
2149OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
2153 getInput().getType().base().getWidthOrSentinel() - getAmount();
2154 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
// Fold hook for TailPrimOp. A constant input is truncated to the width of
// the result type.
// NOTE(review): the definition of `cst` and any guards are not visible in
// this view.
2160OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
2164 cst->trunc(getType().base().getWidthOrSentinel()));
// Canonicalizes TailPrimOp. The visible code reads the input width and the
// drop amount and tests whether the drop amount differs from the full input
// width.
// NOTE(review): the rewrite performed under these checks is not visible in
// this view.
2168LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
2169 PatternRewriter &rewriter) {
2170 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2171 if (inputWidth <= 0)
2175 unsigned dropAmount = op.getAmount();
2176 if (dropAmount !=
unsigned(inputWidth))
// Registers the canonicalization patterns for SubaccessOp; only
// patterns::SubaccessOfConstant is added.
2182void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
2184 results.add<patterns::SubaccessOfConstant>(
context);
// Fold hook for MultibitMuxOp.
//  * With a single input, operand 1 is returned (presumably operand 0 is
//    the index — TODO confirm).
//  * With a constant, in-range index, the selected input is returned.
//    Inputs are addressed from the back: index `i` selects
//    `getInputs()[size - 1 - i]`.
2187OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
2189 if (adaptor.getInputs().size() == 1)
2190 return getOperand(1);
2192 if (
auto constIndex =
getConstant(adaptor.getIndex())) {
2193 auto index = constIndex->getZExtValue();
2194 if (index < getInputs().size())
2195 return getInputs()[getInputs().size() - 1 - index];
// Canonicalizes MultibitMuxOp. The visible rewrites are, in order:
//  1. a check whether every input equals the first one;
//  2. erasing leading inputs when the index is too narrow to address them;
//  3. turning a mux over consecutive subindexes of one vector into a
//     SubaccessOp;
//  4. turning a two-input mux with a one-bit index into a MuxPrimOp.
// NOTE(review): the actions and return statements following several of the
// checks are not visible in this view.
2201LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
2202 PatternRewriter &rewriter) {
2206 if (llvm::all_of(op.getInputs().drop_front(), [&](
auto input) {
2207 return input == op.getInputs().front();
2215 auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
2216 uint64_t inputSize = op.getInputs().size();
// An index of `indexWidth` bits can address at most 2^indexWidth inputs, so
// the first `inputSize - 2^indexWidth` inputs are dropped.
2217 if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
2218 rewriter.modifyOpInPlace(op, [&]() {
2219 op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
2226 if (
auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
// Every input must be a subindex of the same value, with the subindex
// number descending from `size - 1` down to 0 across the input list.
2227 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](
auto e) {
2228 auto subindex = e.value().template getDefiningOp<SubindexOp>();
2229 return subindex && lastSubindex.getInput() == subindex.getInput() &&
2230 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
2232 replaceOpWithNewOpAndCopyName<SubaccessOp>(
2233 rewriter, op, lastSubindex.getInput(), op.getIndex());
2239 if (op.getInputs().size() != 2)
2243 auto uintType = op.getIndex().getType();
2244 if (uintType.getBitWidthOrSentinel() != 1)
// The index becomes the mux selector, input 0 the high value and input 1
// the low value.
2248 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
2249 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
// Fragment of a helper whose signature is not visible in this view.
// It scans the users of `value` for connects that drive it and remembers a
// MatchingConnectOp in `connect`. The visible condition rejects a driver
// that is not a MatchingConnectOp, that differs from one already recorded,
// or that lives in a block other than the parent block of `value`.
// NOTE(review): what happens on rejection, and for the Attach/Subfield/
// Subaccess/Subindex users, is not visible here.
2268 MatchingConnectOp connect;
2269 for (Operation *user : value.getUsers()) {
2271 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
2274 if (
auto aConnect = dyn_cast<FConnectLike>(user))
2275 if (aConnect.getDest() == value) {
2276 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
2279 if (!matchingConnect || (connect && connect != matchingConnect) ||
2280 matchingConnect->getBlock() != value.getParentBlock())
2282 connect = matchingConnect;
// Fragment of a connect canonicalization whose first signature line is not
// visible in this view. It inspects the declaration defining the connect
// destination (a WireOp or RegOp is expected), the op defining the connect
// source, and finally erases the connect `op`. A source op that is not
// already first in the declaration's block is moved to the front of it.
// NOTE(review): most guard bodies and the replacement of the declaration
// are not visible here, so the exact conditions are not documented.
2290 PatternRewriter &rewriter) {
2293 Operation *connectedDecl = op.getDest().getDefiningOp();
2298 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
2302 cast<Forceable>(connectedDecl).isForceable())
2310 if (connectedDecl->hasOneUse())
2314 auto *declBlock = connectedDecl->getBlock();
2315 auto *srcValueOp = op.getSrc().getDefiningOp();
2318 if (!isa<WireOp>(connectedDecl))
2324 if (!isa<ConstantOp>(srcValueOp))
2326 if (srcValueOp->getBlock() != declBlock)
2332 auto replacement = op.getSrc();
// Hoist the source op to the very start of the declaration's block.
2335 if (srcValueOp && srcValueOp != &declBlock->front())
2336 srcValueOp->moveBefore(&declBlock->front());
2343 rewriter.eraseOp(op);
// Registers the canonicalization patterns for ConnectOp:
// patterns::ConnectExtension and patterns::ConnectSameType.
2347void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2349 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
// Canonicalization entry point for MatchingConnectOp.
// NOTE(review): the body is not visible in this view.
2353LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2354 PatternRewriter &rewriter) {
// Fragment of a helper whose signature is not visible in this view.
// It walks the users of `value` looking for an AttachOp other than
// `dominatedAttach` that appears before `dominatedAttach` in its block.
// NOTE(review): the statements executed on a match are not visible here.
2371 for (
auto *user : value.getUsers()) {
2372 auto attach = dyn_cast<AttachOp>(user);
2373 if (!attach || attach == dominatedAttach)
2375 if (attach->isBeforeInBlock(dominatedAttach))
// Canonicalizes AttachOp. Visible rewrites:
//  * an attach with at most one operand is erased;
//  * two attaches are merged into one new AttachOp carrying the operands of
//    both (the shared `operand` is not added twice), and both old ops are
//    erased;
//  * an operand defined by a WireOp that has a single use, is not marked
//    don't-touch and is not forceable is dropped: a new AttachOp is created
//    from the remaining operands and the old attach and the wire are erased.
// NOTE(review): how the second attach (`attach`) is found is not visible in
// this view.
2381LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
2383 if (op.getNumOperands() <= 1) {
2384 rewriter.eraseOp(op);
2388 for (
auto operand : op.getOperands()) {
2395 SmallVector<Value> newOperands(op.getOperands());
2396 for (
auto newOperand : attach.getOperands())
2397 if (newOperand != operand)
2398 newOperands.push_back(newOperand);
2399 AttachOp::create(rewriter, op->getLoc(), newOperands);
2400 rewriter.eraseOp(attach);
2401 rewriter.eraseOp(op);
2409 if (
auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
2410 if (!
hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
2411 !wire.isForceable()) {
2412 SmallVector<Value> newOperands;
2413 for (
auto newOperand : op.getOperands())
2414 if (newOperand != operand)
2415 newOperands.push_back(newOperand);
2417 AttachOp::create(rewriter, op->getLoc(), newOperands);
2418 rewriter.eraseOp(op);
2419 rewriter.eraseOp(wire);
2430 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
2431 rewriter.inlineBlockBefore(®ion.front(), op, {});
// Canonicalizes WhenOp. Visible rewrites:
//  * a condition defined by a firrtl::ConstantOp: the branches are
//    distinguished by whether the constant is all ones, and the when op is
//    erased afterwards;
//  * a non-empty then block with an empty else block: the else block is
//    erased;
//  * an empty then block with no (or an empty) else region: the op is
//    erased.
// NOTE(review): the statements that splice the selected region in place of
// the op are not visible in this view.
2434LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
2435 if (
auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
2436 if (constant.getValue().isAllOnes())
2438 else if (op.hasElseRegion() && !op.getElseRegion().empty())
2441 rewriter.eraseOp(op);
2447 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
2448 op.getElseBlock().empty()) {
2449 rewriter.eraseBlock(&op.getElseBlock());
2456 if (!op.getThenBlock().empty())
2460 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
2461 rewriter.eraseOp(op);
// Rewrite pattern on NodeOp that replaces a node by its input value. The
// visible guard requires the node's name to be droppable and the node to
// carry no inner symbol.
// NOTE(review): the remaining guard conditions and what is done with `name`
// and `newOp` are not visible in this view.
2471 using OpRewritePattern::OpRewritePattern;
2472 LogicalResult matchAndRewrite(NodeOp node,
2473 PatternRewriter &rewriter)
const override {
2474 auto name = node.getNameAttr();
2475 if (!node.hasDroppableName() || node.getInnerSym() ||
2478 auto *newOp = node.getInput().getDefiningOp();
2481 rewriter.replaceOp(node, node.getInput());
// Rewrite pattern on NodeOp that redirects all uses of the node's result to
// the node's input. The visible guard involves the node having no uses or
// being forceable.
// NOTE(review): the first part of the guard is not visible in this view.
2488 using OpRewritePattern::OpRewritePattern;
2489 LogicalResult matchAndRewrite(NodeOp node,
2490 PatternRewriter &rewriter)
const override {
2492 node.use_empty() || node.isForceable())
2494 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
// Template helper parameterized on the op type. The visible guard tests
// whether `op` is not forceable or its data reference still has uses.
// NOTE(review): the function name line and the rewrite itself are not
// visible in this view.
2501template <
typename OpTy>
2503 PatternRewriter &rewriter) {
2504 if (!op.isForceable() || !op.getDataRef().use_empty())
// Multi-result fold hook for NodeOp. When the input folded to a constant
// attribute, that attribute is pushed as a fold result.
// NOTE(review): the earlier guards and the return values are not visible in
// this view.
2512LogicalResult NodeOp::fold(FoldAdaptor adaptor,
2513 SmallVectorImpl<OpFoldResult> &results) {
2522 if (!adaptor.getInput())
2525 results.push_back(adaptor.getInput());
// Registers the canonicalization patterns for NodeOp: the FoldNodeName
// pattern and the demoteForceableIfUnused<NodeOp> function pattern.
2529void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2531 results.insert<FoldNodeName>(
context);
2532 results.add(demoteForceableIfUnused<NodeOp>);
// Generic rewrite pattern that merges element-wise connects to an aggregate
// (bundle or vector) into a single whole-aggregate connect. It is anchored
// on an operation name supplied by the derived patterns.
2538struct AggOneShot :
public mlir::RewritePattern {
// NOTE(review): `weight` is accepted but a benefit of 0 is passed to the
// base class regardless — confirm this is intended.
2539 AggOneShot(StringRef name, uint32_t weight, MLIRContext *
context)
2540 : RewritePattern(name, 0,
context) {}
// Collects, per field/element index, the source value connected to each
// element of result 0 of `lhs`. Only MatchingConnectOps on SubfieldOp or
// SubindexOp users are recorded, and users must share the parent op of
// `lhs`. The values are returned in index order.
// NOTE(review): the bail-out statements (duplicate writes, whole-value
// connects, missing fields, other users) are not visible in this view;
// presumably they return an empty vector.
2542 SmallVector<Value> getCompleteWrite(Operation *lhs)
const {
2543 auto lhsTy = lhs->getResult(0).getType();
2544 if (!type_isa<BundleType, FVectorType>(lhsTy))
2547 DenseMap<uint32_t, Value> fields;
2548 for (Operation *user : lhs->getResult(0).getUsers()) {
2549 if (user->getParentOp() != lhs->getParentOp())
2551 if (
auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2552 if (aConnect.getDest() == lhs->getResult(0))
2554 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2555 for (Operation *subuser : subField.getResult().getUsers()) {
2556 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2557 if (aConnect.getDest() == subField) {
2558 if (subuser->getParentOp() != lhs->getParentOp())
2560 if (fields.count(subField.getFieldIndex()))
2562 fields[subField.getFieldIndex()] = aConnect.getSrc();
2568 }
else if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2569 for (Operation *subuser : subIndex.getResult().getUsers()) {
2570 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2571 if (aConnect.getDest() == subIndex) {
2572 if (subuser->getParentOp() != lhs->getParentOp())
2574 if (fields.count(subIndex.getIndex()))
2576 fields[subIndex.getIndex()] = aConnect.getSrc();
// Emit the collected sources in element order; every element must have one.
2587 SmallVector<Value> values;
2588 uint32_t total = type_isa<BundleType>(lhsTy)
2589 ? type_cast<BundleType>(lhsTy).getNumElements()
2590 : type_cast<FVectorType>(lhsTy).getNumElements();
2591 for (uint32_t i = 0; i < total; ++i) {
2592 if (!fields.count(i))
2594 values.push_back(fields[i]);
// Builds a BundleCreateOp or VectorCreateOp from the collected element
// sources at the end of the block of `op`, connects it to result 0 of `op`,
// and erases the per-element connects it replaces.
2599 LogicalResult matchAndRewrite(Operation *op,
2600 PatternRewriter &rewriter)
const override {
2601 auto values = getCompleteWrite(op);
2604 rewriter.setInsertionPointToEnd(op->getBlock());
2605 auto dest = op->getResult(0);
2606 auto destType = dest.getType();
2609 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2612 Value newVal = type_isa<BundleType>(destType)
2613 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2615 : rewriter.createOrFold<VectorCreateOp>(
2616 op->
getLoc(), destType, values);
2617 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2618 for (Operation *user : dest.getUsers()) {
2619 if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
// Early-increment iteration because connects are erased while walking.
2620 for (Operation *subuser :
2621 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2622 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2623 if (aConnect.getDest() == subIndex)
2624 rewriter.eraseOp(aConnect);
2625 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2626 for (Operation *subuser :
2627 llvm::make_early_inc_range(subField.getResult().getUsers()))
2628 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2629 if (aConnect.getDest() == subField)
2630 rewriter.eraseOp(aConnect);
// AggOneShot anchored on WireOp (weight argument 0).
2637struct WireAggOneShot :
public AggOneShot {
2638 WireAggOneShot(MLIRContext *
context)
2639 : AggOneShot(WireOp::getOperationName(), 0,
context) {}
// AggOneShot anchored on SubindexOp (weight argument 0).
2641struct SubindexAggOneShot :
public AggOneShot {
2642 SubindexAggOneShot(MLIRContext *
context)
2643 : AggOneShot(SubindexOp::getOperationName(), 0,
context) {}
// AggOneShot anchored on SubfieldOp (weight argument 0).
2645struct SubfieldAggOneShot :
public AggOneShot {
2646 SubfieldAggOneShot(MLIRContext *
context)
2647 : AggOneShot(SubfieldOp::getOperationName(), 0,
context) {}
// Registers the canonicalization patterns for WireOp: WireAggOneShot and
// the demoteForceableIfUnused<WireOp> function pattern.
2651void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2653 results.insert<WireAggOneShot>(
context);
2654 results.add(demoteForceableIfUnused<WireOp>);
// Registers the canonicalization patterns for SubindexOp; only
// SubindexAggOneShot is added.
2657void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2659 results.insert<SubindexAggOneShot>(
context);
// Fold hook for SubindexOp. When the input folded to an ArrayAttr, the
// element at this op's index is returned.
// NOTE(review): the guard for a non-array input is not visible in this
// view.
2662OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2663 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2666 return attr[getIndex()];
// Fold hook for SubfieldOp. The input's constant value is read as an
// ArrayAttr and the field index of this op is looked up.
// NOTE(review): the guard and the returned element are not visible in this
// view.
2669OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2670 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2673 auto index = getFieldIndex();
// Registers the canonicalization patterns for SubfieldOp; only
// SubfieldAggOneShot is added.
2677void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2679 results.insert<SubfieldAggOneShot>(
context);
// Fragment of a helper whose first signature line is not visible in this
// view. It iterates over the operand attributes and finally packs all of
// them into an ArrayAttr.
// NOTE(review): the per-operand check inside the loop is not visible here.
2683 ArrayRef<Attribute> operands) {
2684 for (
auto operand : operands)
2687 return ArrayAttr::get(
context, operands);
// Fold hook for BundleCreateOp. Visible case: when operand 0 is field 0 of
// a bundle whose type equals the result type, and each following operand
// `i` is field `i` of that same bundle, the original bundle is returned.
// NOTE(review): the constant-folding path and the call wrapping the lambda
// are not visible in this view.
2690OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2693 if (getNumOperands() > 0)
2694 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2695 if (first.getFieldIndex() == 0 &&
2696 first.getInput().getType() == getType() &&
2698 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2700 elem.value().
template getDefiningOp<SubfieldOp>();
2701 return subindex && subindex.getInput() == first.getInput() &&
2702 subindex.getFieldIndex() == elem.index();
2704 return first.getInput();
// Fold hook for VectorCreateOp. Visible case: when operand 0 is element 0
// of a vector whose type equals the result type, and each following operand
// `i` is element `i` of that same vector, the original vector is returned.
// NOTE(review): the constant-folding path and the call wrapping the lambda
// are not visible in this view.
2709OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2712 if (getNumOperands() > 0)
2713 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2714 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2716 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2718 elem.value().
template getDefiningOp<SubindexOp>();
2719 return subindex && subindex.getInput() == first.getInput() &&
2720 subindex.getIndex() == elem.index();
2722 return first.getInput();
// Fold hook for UninferredResetCastOp: a cast whose operand type already
// equals the result type folds to its operand.
2727OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2728 if (getOperand().getType() == getType())
2729 return getOperand();
// Rewrite pattern on RegResetOp whose reset value is a ConstantOp. It looks
// at a mux feeding the register's connect where one arm is a constant and
// the other arm is the register itself. The constant must have the same
// type and value as the reset constant, and the connect, mux arms and
// register must all share one type of known width. The constant is hoisted
// to the front of the connect's block and the connect is erased.
// NOTE(review): how `reset` and `con` are obtained, the guard bodies, and
// the replacement that precedes the erase are not visible in this view.
2737 using OpRewritePattern::OpRewritePattern;
2738 LogicalResult matchAndRewrite(RegResetOp reg,
2739 PatternRewriter &rewriter)
const override {
2741 dyn_cast_or_null<ConstantOp>(
reg.getResetValue().getDefiningOp());
2750 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2753 auto *high = mux.getHigh().getDefiningOp();
2754 auto *low = mux.getLow().getDefiningOp();
2755 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2757 if (constOp && low != reg)
// Mirror case: the constant sits on the low arm and the register on the
// high arm.
2759 if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
2760 constOp = dyn_cast<ConstantOp>(low);
2762 if (!constOp || constOp.getType() != reset.getType() ||
2763 constOp.getValue() != reset.getValue())
2767 auto regTy =
reg.getResult().getType();
2768 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2769 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2770 regTy.getBitWidthOrSentinel() < 0)
// Move the constant to the start of the block.
2776 if (constOp != &con->getBlock()->front())
2777 constOp->moveBefore(&con->getBlock()->front());
2782 rewriter.eraseOp(con);
// Fragment of a predicate whose signature is not visible in this view.
// For a value defined by a ConstantOp it reports whether the constant is
// one; for a SpecialConstantOp it returns that op's value.
// NOTE(review): the result for any other defining op is not visible here.
2789 if (
auto c = v.getDefiningOp<ConstantOp>())
2790 return c.getValue().isOne();
2791 if (
auto sc = v.getDefiningOp<SpecialConstantOp>())
2792 return sc.getValue();
// Fragment of a register canonicalization whose signature is not visible in
// this view. When the register type equals the reset value type, the writes
// to the register are dropped and the register is replaced by a NodeOp of
// the reset value that keeps the register's name, name kind, annotations,
// inner symbol and forceable flag.
// NOTE(review): the guards preceding this code are not visible here.
2801 auto resetValue = reg.getResetValue();
2802 if (reg.getType(0) != resetValue.getType())
2806 (void)
dropWrite(rewriter, reg->getResult(0), {});
2807 replaceOpWithNewOpAndCopyName<NodeOp>(
2808 rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
2809 reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
// Registers the canonicalization patterns for RegResetOp:
// patterns::RegResetWithZeroReset, FoldResetMux and the
// demoteForceableIfUnused<RegResetOp> function pattern.
// NOTE(review): a registration between these two calls may be hidden from
// this view.
2813void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2815 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(
context);
2817 results.add(demoteForceableIfUnused<RegResetOp>);
// Fragment of a memory-port helper whose signature is not visible in this
// view. It locates the bundle field `name` on the port, scans the port's
// subfield accesses for that field, takes the source of a connect (`conn`)
// as the field's value, and finally reports whether that value is a
// ConstantOp equal to zero.
// NOTE(review): how `conn` is found and the bail-out paths are not visible
// here.
2822 auto portTy = type_cast<BundleType>(port.getType());
2823 auto fieldIndex = portTy.getElementIndex(name);
2824 assert(fieldIndex &&
"missing field on memory port");
2827 for (
auto *op : port.getUsers()) {
// Every user of a port is cast to SubfieldOp (asserts otherwise).
2828 auto portAccess = cast<SubfieldOp>(op);
2829 if (fieldIndex != portAccess.getFieldIndex())
2834 value = conn.getSrc();
2844 auto portConst = value.getDefiningOp<ConstantOp>();
2847 return portConst.getValue().isZero();
// Fragment of a memory-port helper whose signature is not visible in this
// view. It locates the bundle field named by `data` on the port and checks
// whether any subfield access of that field still has uses.
// NOTE(review): the assert message mentions an enable flag although the
// field looked up is `data` — confirm the wording. Return statements are
// not visible here.
2852 auto portTy = type_cast<BundleType>(port.getType());
2853 auto fieldIndex = portTy.getElementIndex(
data);
2854 assert(fieldIndex &&
"missing enable flag on memory port");
2856 for (
auto *op : port.getUsers()) {
2857 auto portAccess = cast<SubfieldOp>(op);
2858 if (fieldIndex != portAccess.getFieldIndex())
2860 if (!portAccess.use_empty())
// Fragment of a memory-port helper whose first signature line is not
// visible in this view. Every subfield access of field `name` on the port
// has its uses replaced by `value` and is then erased.
2869 StringRef name, Value value) {
2870 auto portTy = type_cast<BundleType>(port.getType());
2871 auto fieldIndex = portTy.getElementIndex(name);
2872 assert(fieldIndex &&
"missing field on memory port");
// Early-increment iteration because accesses are erased while walking.
2874 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2875 auto portAccess = cast<SubfieldOp>(op);
2876 if (fieldIndex != portAccess.getFieldIndex())
2878 rewriter.replaceAllUsesWith(portAccess, value);
2879 rewriter.eraseOp(portAccess);
// Removes all uses of a memory port so the memory can be rewritten.
// Connects that drive a field of the port are erased, and field accesses
// without remaining uses are erased. Accesses that are still read are
// replaced by a fresh RegOp of the field type, clocked by a clock obtained
// from `getClock` (a SpecialConstantOp of ClockType). On exit the port must
// have no uses left.
// NOTE(review): the first loop appears to replace the whole port by a RegOp
// when a user is not a SubfieldOp; its guard is not visible in this view.
2884static void erasePort(PatternRewriter &rewriter, Value port) {
2887 auto getClock = [&] {
2889 clock = SpecialConstantOp::create(rewriter, port.getLoc(),
2890 ClockType::get(rewriter.getContext()),
2899 for (
auto *op : port.getUsers()) {
2900 auto subfield = dyn_cast<SubfieldOp>(op);
2902 auto ty = port.getType();
2903 auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
2904 rewriter.replaceAllUsesWith(port, reg.getResult());
2913 for (
auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2914 auto access = cast<SubfieldOp>(accessOp);
// Drop every connect that writes to this field.
2915 for (
auto *user : llvm::make_early_inc_range(access->getUsers())) {
2916 auto connect = dyn_cast<FConnectLike>(user);
2917 if (connect && connect.getDest() == access) {
2918 rewriter.eraseOp(user);
2922 if (access.use_empty()) {
2923 rewriter.eraseOp(access);
// The field is still read somewhere: substitute a register for it.
2929 auto ty = access.getType();
2930 auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
2931 rewriter.replaceOp(access, reg.getResult());
2933 assert(port.use_empty() &&
"port should have no remaining uses");
// Rewrite pattern on MemOp for memories whose data type is an integer of
// bit width zero. Every subfield access of every port is replaced by a
// WireOp of the same type; wires for fields whose name ends with "data" are
// additionally driven with a zero-width zero constant. The memory is then
// erased.
// NOTE(review): the guard bodies (including the reaction to a non-subfield
// user) are not visible in this view.
2939 using OpRewritePattern::OpRewritePattern;
2940 LogicalResult matchAndRewrite(MemOp mem,
2941 PatternRewriter &rewriter)
const override {
2945 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2946 mem.getDataType().getBitWidthOrSentinel() != 0)
// All port users must be subfield accesses.
2950 for (
auto port : mem.getResults())
2951 for (auto *user : port.getUsers())
2952 if (!isa<SubfieldOp>(user))
2957 for (
auto port : mem.getResults()) {
2958 for (
auto *user :
llvm::make_early_inc_range(port.getUsers())) {
2959 SubfieldOp sfop = cast<SubfieldOp>(user);
2960 StringRef fieldName = sfop.getFieldName();
2961 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2962 rewriter, sfop, sfop.getResult().getType())
2964 if (fieldName.ends_with(
"data")) {
2966 auto zero = firrtl::ConstantOp::create(
2967 rewriter, wire.getLoc(),
2968 firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
2969 MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
2973 rewriter.eraseOp(mem);
// Rewrite pattern on MemOp that classifies the ports by kind while tracking
// `isRead` / `isWritten`, asserts that the memory is not both read and
// written, checks a read memory for an initializer, visits each port and
// finally erases the memory.
// NOTE(review): the per-kind case bodies, the reaction to an initialized
// read memory and the per-port action are not visible in this view.
2980 using OpRewritePattern::OpRewritePattern;
2981 LogicalResult matchAndRewrite(MemOp mem,
2982 PatternRewriter &rewriter)
const override {
2985 bool isRead =
false, isWritten =
false;
2986 for (
unsigned i = 0; i < mem.getNumResults(); ++i) {
2987 switch (mem.getPortKind(i)) {
2988 case MemOp::PortKind::Read:
2993 case MemOp::PortKind::Write:
2998 case MemOp::PortKind::Debug:
2999 case MemOp::PortKind::ReadWrite:
3002 llvm_unreachable(
"unknown port kind");
3004 assert((!isWritten || !isRead) &&
"memory is in use");
3009 if (isRead && mem.getInit())
3012 for (
auto port : mem.getResults())
3015 rewriter.eraseOp(mem);
// Rewrite pattern on MemOp that removes dead ports. Dead ports are tracked
// in a bit vector indexed by port number. A new MemOp is created from the
// collected result types, port names and port annotations (same latencies,
// depth, ruw, name, annotations, inner symbol, init and prefix). Uses of
// the old ports are redirected to consecutive results of the new memory,
// and the old memory is erased.
// NOTE(review): the statements that mark ports dead, skip them while
// collecting, and erase them are not visible in this view.
3022 using OpRewritePattern::OpRewritePattern;
3023 LogicalResult matchAndRewrite(MemOp mem,
3024 PatternRewriter &rewriter)
const override {
3028 llvm::SmallBitVector deadPorts(mem.getNumResults());
3029 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3031 if (!mem.getPortAnnotation(i).empty())
3035 auto kind = mem.getPortKind(i);
3036 if (kind == MemOp::PortKind::Debug)
// A read port whose "data" field is unused.
3045 if (kind == MemOp::PortKind::Read &&
isPortUnused(port,
"data")) {
3050 if (deadPorts.none())
3054 SmallVector<Type> resultTypes;
3055 SmallVector<StringRef> portNames;
3056 SmallVector<Attribute> portAnnotations;
3057 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3060 resultTypes.push_back(port.getType());
3061 portNames.push_back(mem.getPortName(i));
3062 portAnnotations.push_back(mem.getPortAnnotation(i));
// A replacement memory is only built when at least one port survives.
3066 if (!resultTypes.empty())
3067 newOp = MemOp::create(
3068 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
3069 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
3070 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
3071 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
3072 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
// `nextPort` numbers the results of the new memory.
3075 unsigned nextPort = 0;
3076 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3080 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
3083 rewriter.eraseOp(mem);
// Rewrite pattern on MemOp that turns selected ReadWrite ports into Write
// ports. Candidate ports are tracked in the `deadReads` bit vector. A new
// MemOp is created in which a converted port gets the Write port type and
// the others keep their type. For a converted port the old fields are
// remapped onto the new port (addr, en, clk, and wdata/wmask onto
// data/mask), accesses of "wmode" are replaced by a fresh WireOp, and other
// ports have their uses redirected to the matching new result. The old
// memory is erased at the end.
// NOTE(review): the test that marks a port in `deadReads` and the branch
// selecting converted versus untouched ports are not visible in this view.
3090 using OpRewritePattern::OpRewritePattern;
3091 LogicalResult matchAndRewrite(MemOp mem,
3092 PatternRewriter &rewriter)
const override {
3097 llvm::SmallBitVector deadReads(mem.getNumResults());
3098 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3099 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
3101 if (!mem.getPortAnnotation(i).empty())
3108 if (deadReads.none())
3111 SmallVector<Type> resultTypes;
3112 SmallVector<StringRef> portNames;
3113 SmallVector<Attribute> portAnnotations;
3114 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3116 resultTypes.push_back(
3117 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
3118 MemOp::PortKind::Write, mem.getMaskBits()));
3120 resultTypes.push_back(port.getType());
3122 portNames.push_back(mem.getPortName(i));
3123 portAnnotations.push_back(mem.getPortAnnotation(i));
3126 auto newOp = MemOp::create(
3127 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
3128 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
3129 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
3130 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
3131 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3133 for (
unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
3134 auto result = mem.getResult(i);
3135 auto newResult = newOp.getResult(i);
3137 auto resultPortTy = type_cast<BundleType>(result.getType());
// Replaces every access of field `fromName` on the old port with one
// access of field `toName` on the new port.
3141 auto replace = [&](StringRef toName, StringRef fromName) {
3142 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
3143 assert(fromFieldIndex &&
"missing enable flag on memory port");
3145 auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
3147 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
3148 auto fromField = cast<SubfieldOp>(op);
3149 if (fromFieldIndex != fromField.getFieldIndex())
3151 rewriter.replaceOp(fromField, toField.getResult());
3155 replace(
"addr",
"addr");
3156 replace(
"en",
"en");
3157 replace(
"clk",
"clk");
3158 replace(
"data",
"wdata");
3159 replace(
"mask",
"wmask");
// The write port has no "wmode" field; its accesses become plain wires.
3162 auto wmodeFieldIndex = resultPortTy.getElementIndex(
"wmode");
3163 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
3164 auto wmodeField = cast<SubfieldOp>(op);
3165 if (wmodeFieldIndex != wmodeField.getFieldIndex())
3167 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
3170 rewriter.replaceAllUsesWith(result, newResult);
3173 rewriter.eraseOp(mem);
// Rewrite pattern on MemOp that narrows the data type to the bits that are
// actually read. Steps visible here:
//  1. Check the memory summary for masking / sequential memories and
//     require an integer data type.
//  2. Collect reads (BitsPrimOp users of the read data fields) and mark
//     their bit ranges in `usedBits`; collect writes (connects to the write
//     data fields).
//  3. Compress the used bits into contiguous `ranges`, computing the new
//     width and, per original low bit, its position in the narrowed word.
//  4. Replace the memory with one of the narrowed data type, recreate the
//     data subfields, re-index the reads, and rewrite each write as a
//     concatenation of the used slices of its source.
// NOTE(review): many guard bodies and `break`/`return` statements are not
// visible in this view, so bail-out behaviour is not documented.
3180 using OpRewritePattern::OpRewritePattern;
3182 LogicalResult matchAndRewrite(MemOp mem,
3183 PatternRewriter &rewriter)
const override {
3188 const auto &summary = mem.getSummary();
3189 if (summary.isMasked || summary.isSeqMem())
3192 auto type = type_dyn_cast<IntType>(mem.getDataType());
3195 auto width = type.getBitWidthOrSentinel();
// One bit per data bit; set when some read extracts that bit.
3199 llvm::SmallBitVector usedBits(width);
// Original low bit of a read -> low bit in the narrowed data word.
3200 DenseMap<unsigned, unsigned> mapping;
3205 SmallVector<BitsPrimOp> readOps;
// Records every BitsPrimOp that reads from field `field` of `port`.
3206 auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
3207 auto portTy = type_cast<BundleType>(port.getType());
3208 auto fieldIndex = portTy.getElementIndex(field);
3209 assert(fieldIndex &&
"missing data port");
3211 for (
auto *op : port.getUsers()) {
3212 auto portAccess = cast<SubfieldOp>(op);
3213 if (fieldIndex != portAccess.getFieldIndex())
3216 for (
auto *user : op->getUsers()) {
3217 auto bits = dyn_cast<BitsPrimOp>(user);
// Marks bits lo..hi inclusive (the end index of `set` is exclusive).
3221 usedBits.set(bits.getLo(), bits.getHi() + 1);
// Placeholder; the real position is filled in once ranges are known.
3225 mapping[bits.getLo()] = 0;
3226 readOps.push_back(bits);
3236 SmallVector<MatchingConnectOp> writeOps;
// Records the connects that drive field `field` of `port`.
3237 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
3238 auto portTy = type_cast<BundleType>(port.getType());
3239 auto fieldIndex = portTy.getElementIndex(field);
3240 assert(fieldIndex &&
"missing data port");
3242 for (
auto *op : port.getUsers()) {
3243 auto portAccess = cast<SubfieldOp>(op);
3244 if (fieldIndex != portAccess.getFieldIndex())
3251 writeOps.push_back(conn);
3257 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3259 if (!mem.getPortAnnotation(i).empty())
3262 switch (mem.getPortKind(i)) {
3263 case MemOp::PortKind::Debug:
3266 case MemOp::PortKind::Write:
3267 if (failed(findWriteUsers(port,
"data")))
3270 case MemOp::PortKind::Read:
3271 if (failed(findReadUsers(port,
"data")))
3274 case MemOp::PortKind::ReadWrite:
3275 if (failed(findWriteUsers(port,
"wdata")))
3277 if (failed(findReadUsers(port,
"rdata")))
3281 llvm_unreachable(
"unknown port kind");
3285 if (usedBits.none())
// Inclusive [start, end] ranges of consecutive used bits, in ascending
// order; `newWidth` counts the used bits seen so far.
3289 SmallVector<std::pair<unsigned, unsigned>> ranges;
3290 unsigned newWidth = 0;
3291 for (
int i = usedBits.find_first(); 0 <= i && i < width;) {
3292 int e = usedBits.find_next_unset(i);
3295 for (
int idx = i; idx < e; ++idx, ++newWidth) {
3296 if (
auto it = mapping.find(idx); it != mapping.end()) {
3297 it->second = newWidth;
3300 ranges.emplace_back(i, e - 1);
3301 i = e != width ? usedBits.find_next(e) : e;
// Build the narrowed memory with the same signedness as the original.
3305 auto newType =
IntType::get(mem->getContext(), type.isSigned(), newWidth);
3306 SmallVector<Type> portTypes;
3307 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3308 portTypes.push_back(
3309 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
3311 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
3312 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
3313 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
3314 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
3315 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
// Recreates the access of `field` on the new port and redirects the old
// accesses of that field to it.
3318 auto rewriteSubfield = [&](Value port, StringRef field) {
3319 auto portTy = type_cast<BundleType>(port.getType());
3320 auto fieldIndex = portTy.getElementIndex(field);
3321 assert(fieldIndex &&
"missing data port");
3323 rewriter.setInsertionPointAfter(newMem);
3324 auto newPortAccess =
3325 SubfieldOp::create(rewriter, port.getLoc(), port, field);
3327 for (
auto *op :
llvm::make_early_inc_range(port.getUsers())) {
3328 auto portAccess = cast<SubfieldOp>(op);
// Skip the access just created and accesses of other fields.
3329 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
3331 rewriter.replaceOp(portAccess, newPortAccess.getResult());
3336 for (
auto [i, port] :
llvm::enumerate(newMem.getResults())) {
3337 switch (newMem.getPortKind(i)) {
3338 case MemOp::PortKind::Debug:
3339 llvm_unreachable(
"cannot rewrite debug port");
3340 case MemOp::PortKind::Write:
3341 rewriteSubfield(port,
"data");
3343 case MemOp::PortKind::Read:
3344 rewriteSubfield(port,
"data");
3346 case MemOp::PortKind::ReadWrite:
3347 rewriteSubfield(port,
"rdata");
3348 rewriteSubfield(port,
"wdata");
3351 llvm_unreachable(
"unknown port kind");
// Re-index each read: the extracted width is kept, the low bit moves to
// its position in the narrowed word.
3355 for (
auto readOp : readOps) {
3356 rewriter.setInsertionPointAfter(readOp);
3357 auto it = mapping.find(readOp.getLo());
3358 assert(it != mapping.end() &&
"bit op mapping not found");
3361 auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
3362 readOp.getLoc(), readOp.getInput(),
3363 readOp.getHi() - readOp.getLo() + it->second, it->second);
3364 rewriter.replaceAllUsesWith(readOp, newReadValue);
3365 rewriter.eraseOp(readOp);
// Rewrite each write to store only the used slices of its source.
3369 for (
auto writeOp : writeOps) {
3370 Value source = writeOp.getSrc();
3371 rewriter.setInsertionPoint(writeOp);
3373 SmallVector<Value> slices;
// Ranges are visited from the highest to the lowest.
3374 for (
auto &[start, end] :
llvm::reverse(ranges)) {
3375 Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
3376 source,
end, start);
3377 slices.push_back(slice);
3381 rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);
// Signed memories get the concatenation converted with AsSIntPrimOp.
3387 if (type.isSigned())
3389 rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);
3391 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
3401 using OpRewritePattern::OpRewritePattern;
3402 LogicalResult matchAndRewrite(MemOp mem,
3403 PatternRewriter &rewriter)
const override {
3408 auto ty = mem.getDataType();
3409 auto loc = mem.getLoc();
3410 auto *block = mem->getBlock();
3414 SmallPtrSet<Operation *, 8> connects;
3415 SmallVector<SubfieldOp> portAccesses;
3416 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3417 if (!mem.getPortAnnotation(i).empty())
3420 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
3421 auto portTy = type_cast<BundleType>(port.getType());
3422 for (
auto field : fields) {
3423 auto fieldIndex = portTy.getElementIndex(field);
3424 assert(fieldIndex &&
"missing field on memory port");
3426 for (
auto *op : port.getUsers()) {
3427 auto portAccess = cast<SubfieldOp>(op);
3428 if (fieldIndex != portAccess.getFieldIndex())
3430 portAccesses.push_back(portAccess);
3431 for (
auto *user : portAccess->getUsers()) {
3432 auto conn = dyn_cast<FConnectLike>(user);
3435 connects.insert(conn);
3442 switch (mem.getPortKind(i)) {
3443 case MemOp::PortKind::Debug:
3445 case MemOp::PortKind::Read:
3446 if (failed(collect({
"clk",
"en",
"addr"})))
3449 case MemOp::PortKind::Write:
3450 if (failed(collect({
"clk",
"en",
"addr",
"data",
"mask"})))
3453 case MemOp::PortKind::ReadWrite:
3454 if (failed(collect({
"clk",
"en",
"addr",
"wmode",
"wdata",
"wmask"})))
3460 if (!portClock || (clock && portClock != clock))
3466 rewriter.setInsertionPointAfter(mem);
3467 auto memWire = WireOp::create(rewriter, loc, ty).getResult();
3473 rewriter.setInsertionPointToEnd(block);
3475 RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();
3478 MatchingConnectOp::create(rewriter, loc, memWire, memReg);
3482 auto pipeline = [&](Value value, Value clock,
const Twine &name,
3484 for (
unsigned i = 0; i < latency; ++i) {
3485 std::string regName;
3487 llvm::raw_string_ostream os(regName);
3488 os << mem.getName() <<
"_" << name <<
"_" << i;
3490 auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
3491 rewriter.getStringAttr(regName))
3493 MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
3499 const unsigned writeStages =
info.writeLatency - 1;
3504 SmallVector<std::tuple<Value, Value, Value>> writes;
3505 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3507 StringRef name = mem.getPortName(i);
3509 auto portPipeline = [&, port = port](StringRef field,
unsigned stages) {
3512 return pipeline(value, portClock, name +
"_" + field, stages);
3515 switch (mem.getPortKind(i)) {
3516 case MemOp::PortKind::Debug:
3517 llvm_unreachable(
"unknown port kind");
3518 case MemOp::PortKind::Read: {
3526 case MemOp::PortKind::Write: {
3527 auto data = portPipeline(
"data", writeStages);
3528 auto en = portPipeline(
"en", writeStages);
3529 auto mask = portPipeline(
"mask", writeStages);
3533 case MemOp::PortKind::ReadWrite: {
3538 auto wdata = portPipeline(
"wdata", writeStages);
3539 auto wmask = portPipeline(
"wmask", writeStages);
3544 auto wen = AndPrimOp::create(rewriter, port.getLoc(),
en,
wmode);
3546 pipeline(wen, portClock, name +
"_wen", writeStages);
3547 writes.emplace_back(
wdata, wenPipelined,
wmask);
3554 Value next = memReg;
3560 Location loc = mem.getLoc();
3561 unsigned maskGran =
info.dataWidth /
info.maskBits;
3562 SmallVector<Value> chunks;
3563 for (
unsigned i = 0; i <
info.maskBits; ++i) {
3564 unsigned hi = (i + 1) * maskGran - 1;
3565 unsigned lo = i * maskGran;
3567 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc,
data, hi, lo);
3568 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3569 auto bit = rewriter.createOrFold<BitsPrimOp>(loc,
mask, i, i);
3570 auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
3571 chunks.push_back(chunk);
3574 std::reverse(chunks.begin(), chunks.end());
3575 masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
3576 next = MuxPrimOp::create(rewriter, next.getLoc(),
en, masked, next);
3578 Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
3579 MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);
3582 for (Operation *conn : connects)
3583 rewriter.eraseOp(
conn);
3584 for (
auto portAccess : portAccesses)
3585 rewriter.eraseOp(portAccess);
3586 rewriter.eraseOp(mem);
3593void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3596 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3597 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3617 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3620 auto *high = mux.getHigh().getDefiningOp();
3621 auto *low = mux.getLow().getDefiningOp();
3623 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3630 bool constReg =
false;
3632 if (constOp && low == reg)
3634 else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
3636 constOp = dyn_cast<ConstantOp>(low);
3643 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3647 auto regTy = reg.getResult().getType();
3648 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3649 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3650 regTy.getBitWidthOrSentinel() < 0)
3656 if (constOp != &con->getBlock()->front())
3657 constOp->moveBefore(&con->getBlock()->front());
3660 SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
3661 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3662 rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
3663 mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
3664 reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
3665 reg.getForceableAttr());
3666 newReg->setDialectAttrs(attrs);
3668 auto pt = rewriter.saveInsertionPoint();
3669 rewriter.setInsertionPoint(con);
3670 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3671 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3672 rewriter.restoreInsertionPoint(pt);
3676LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3677 if (!
hasDontTouch(op.getOperation()) && !op.isForceable() &&
3693 PatternRewriter &rewriter,
3696 if (
auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3697 if (constant.getValue().isZero()) {
3698 rewriter.eraseOp(op);
3704 if (
auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3705 if (constant.getValue().isZero() == eraseIfZero) {
3706 rewriter.eraseOp(op);
3714template <
class Op,
bool EraseIfZero = false>
3716 PatternRewriter &rewriter) {
3721void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3723 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3724 results.add<patterns::AssertXWhenX>(
context);
3727void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3729 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3730 results.add<patterns::AssumeXWhenX>(
context);
3733void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3734 RewritePatternSet &results, MLIRContext *
context) {
3735 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3736 results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(
context);
3739void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3741 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3748LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3749 PatternRewriter &rewriter) {
3751 if (op.use_empty()) {
3752 rewriter.eraseOp(op);
3759 if (op->hasOneUse() &&
3760 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3761 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3762 *op->user_begin()) ||
3763 (isa<CvtPrimOp>(*op->user_begin()) &&
3764 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3765 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3766 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3767 .getBitWidthOrSentinel() > 0))) {
3768 auto *modop = *op->user_begin();
3769 auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
3770 modop->getResult(0).getType());
3771 rewriter.replaceAllOpUsesWith(modop, inv);
3772 rewriter.eraseOp(modop);
3773 rewriter.eraseOp(op);
3779OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3780 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3781 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3789OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3798 return BoolAttr::get(getContext(),
false);
3802 return BoolAttr::get(getContext(),
false);
3807LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3808 PatternRewriter &rewriter) {
3810 if (
auto testEnable = op.getTestEnable()) {
3811 if (
auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3812 if (constOp.getValue().isZero()) {
3813 rewriter.modifyOpInPlace(op,
3814 [&] { op.getTestEnableMutable().clear(); });
3830 auto forceable = op.getRef().getDefiningOp<Forceable>();
3831 if (!forceable || !forceable.isForceable() ||
3832 op.getRef() != forceable.getDataRef() ||
3833 op.getType() != forceable.getDataType())
3835 rewriter.replaceAllUsesWith(op, forceable.getData());
3839void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3841 results.insert<patterns::RefResolveOfRefSend>(
context);
3845OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3847 if (getInput().getType() == getType())
3853 auto constOp = operand.getDefiningOp<ConstantOp>();
3854 return constOp && constOp.getValue().isZero();
3857template <
typename Op>
3860 rewriter.eraseOp(op);
3866void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3868 results.add(eraseIfPredFalse<RefForceOp>);
3870void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3872 results.add(eraseIfPredFalse<RefForceInitialOp>);
3874void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3876 results.add(eraseIfPredFalse<RefReleaseOp>);
3878void RefReleaseInitialOp::getCanonicalizationPatterns(
3879 RewritePatternSet &results, MLIRContext *
context) {
3880 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3887OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3893 if (adaptor.getReset())
3898 if (
isUInt1(getReset().getType()) && adaptor.getClock())
3911 [&](
auto ty) ->
bool {
return isTypeEmpty(ty.getElementType()); })
3912 .Case<BundleType>([&](
auto ty) ->
bool {
3913 for (
auto elem : ty.getElements())
3918 .Case<IntType>([&](
auto ty) {
return ty.getWidth() == 0; })
3919 .Default([](
auto) ->
bool {
return false; });
3922LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3923 PatternRewriter &rewriter) {
3924 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3931 rewriter.eraseOp(op);
3939LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3940 PatternRewriter &rewriter) {
3943 if (op.getBody()->empty()) {
3944 rewriter.eraseOp(op);
3955OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3957 if (getDomains().
empty())
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user that strictly dominates "dominatedAttach", then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static std::optional< bool > getBoolValue(Attribute attr)
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
LogicalResult matchAndRewrite(BitsPrimOp bits, mlir::PatternRewriter &rewriter) const override