20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallPtrSet.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
29using namespace firrtl;
/// Erases every connect-like user whose destination is `old`.
/// The users are first gathered into a set (the insertion line is not visible
/// here); presumably this is so that erasing does not disturb iteration over
/// the use list. NOTE(review): the returned Value is not visible in this
/// fragment — confirm against the full source.
33static Value
dropWrite(PatternRewriter &rewriter, OpResult old,
35 SmallPtrSet<Operation *, 8> users;
36 for (
auto *user : old.getUsers())
38 for (Operation *user : users)
// Only connects that write *into* `old` are erased; other users are kept.
39 if (
auto connect = dyn_cast<FConnectLike>(user))
40 if (connect.getDest() == old)
41 rewriter.eraseOp(user);
// Fragment of a predicate on `op` (the signature is outside the visible lines).
// Ops with regions take an early exit whose result is not shown. Otherwise the
// predicate holds for side-effect-free ops (mlir::isPure) and for NodeOp/WireOp.
51 if (op->getNumRegions() != 0)
53 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
// Fragment: copies the "name" attribute from the owner of `old` onto the op
// that defines `passthrough`, which is asserted to exist.
// NOTE(review): line 67 (not visible) may guard the setAttr call — confirm.
61 Operation *op = passthrough.getDefiningOp();
64 assert(op &&
"passthrough must be an operation");
65 Operation *oldOp = old.getOwner();
66 auto name = oldOp->getAttrOfType<StringAttr>(
"name");
68 op->setAttr(
"name", name);
76#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
// Fragment: checks that result 0 and every operand are IntTypes with a known
// width. The early-exit returns are not visible here.
84 auto resultType = type_cast<IntType>(op->getResult(0).getType());
85 if (!resultType.hasWidth())
87 for (Value operand : op->getOperands())
88 if (!type_cast<IntType>(operand.getType()).hasWidth())
// Fragment: tests whether `type` is a UInt with a known width of exactly 1.
95 auto t = type_dyn_cast<UIntType>(type);
96 if (!t || !t.hasWidth() || t.getWidth() != 1)
/// Sets the "name" attribute of `op` from `name`.
/// Instances and registers must never be renamed (asserted). When `op` already
/// has a name, `chooseName` picks between the existing and the incoming name
/// (the guard line is not visible here).
103static void updateName(PatternRewriter &rewriter, Operation *op,
108 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) &&
"Should never rename");
109 auto newName = name.getValue();
110 auto newOpName = op->getAttrOfType<StringAttr>(
"name");
113 newName =
chooseName(newOpName.getValue(), name.getValue());
// Only touch the op (and notify the rewriter) when the name actually changes.
115 if (!newOpName || newOpName.getValue() != newName)
116 rewriter.modifyOpInPlace(
117 op, [&] { op->setAttr(
"name", rewriter.getStringAttr(newName)); });
// Fragment: when the replacement value has a defining op, the "name" attribute
// of `op` is read (presumably to transfer it — the use is not visible). Then
// `op` is replaced by `newValue`.
125 if (
auto *newOp = newValue.getDefiningOp()) {
126 auto name = op->getAttrOfType<StringAttr>(
"name");
129 rewriter.replaceOp(op, newValue);
/// Replaces `op` with a newly built OpTy, forwarding `args` to its builder.
/// The "name" attribute of `op` is read first. NOTE(review): the line that
/// copies it onto the new op is not visible here — confirm.
135template <
typename OpTy,
typename... Args>
137 Operation *op, Args &&...args) {
138 auto name = op->getAttrOfType<StringAttr>(
"name");
140 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
// Fragment: ops implementing FNamableOp answer through hasDroppableName();
// the fallback for other ops is not visible.
148 if (
auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
149 return namableOp.hasDroppableName();
/// Fragment of getExtendedConstant: produces an optional APSInt for `operand`,
/// which must have an integer type (asserted). `destWidth` is used as the
/// width of the produced value (presumably the extension target — the other
/// extension lines are not visible).
160static std::optional<APSInt>
162 assert(type_cast<IntType>(operand.getType()) &&
163 "getExtendedConstant is limited to integer types");
170 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
// A zero-width operand carries no bits; it yields a zero APSInt of destWidth
// bits with the operand's signedness.
175 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
176 return APSInt(destWidth,
177 type_cast<IntType>(operand.getType()).isUnsigned());
// Fragment of a constant decoder: a BoolAttr becomes a 1-bit APSInt and an
// IntegerAttr yields its APSInt. Other cases are not visible.
185 if (
auto attr = dyn_cast<BoolAttr>(operand))
186 return APSInt(APInt(1, attr.getValue()));
187 if (
auto attr = dyn_cast<IntegerAttr>(operand))
188 return attr.getAPSInt();
// Fragment of a separate helper: reports whether a decoded constant is zero.
196 return cst->isZero();
/// Fragment of the shared binary constant folder. It takes exactly two
/// operand attributes (asserted) and a `calculate` callback.
/// - Unknown result width (negative sentinel) takes an early exit (not shown).
/// - A zero-width result folds to a 0-bit constant.
/// - Otherwise it computes `calculate(*lhs, *rhs)` at a working width and
///   truncates to the result width.
223 Operation *op, ArrayRef<Attribute> operands,
BinOpKind opKind,
224 const function_ref<APInt(
const APSInt &,
const APSInt &)> &calculate) {
225 assert(operands.size() == 2 &&
"binary op takes two operands");
228 auto resultType = type_cast<IntType>(op->getResult(0).getType());
229 if (resultType.getWidthOrSentinel() < 0)
233 if (resultType.getWidthOrSentinel() == 0)
234 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
// Operand widths start from the operand types...
240 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
242 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
// ...and are widened to cover any constant attribute's actual bit width.
243 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
244 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
245 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
246 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
// The working width depends on `opKind` (the dispatch lines are not visible).
// The candidates are: the result width; max(1, widest operand); or the max of
// the operand widths and the result width.
250 int32_t operandWidth;
253 operandWidth = resultType.getWidthOrSentinel();
258 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
262 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
273 APInt resultValue = calculate(*lhs, *rhs);
// Narrow to the declared result width; the assert below checks that it fits.
278 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
280 assert((
unsigned)resultType.getWidthOrSentinel() ==
281 resultValue.getBitWidth());
/// Fragment of a helper that runs `canonicalize` over an op's constant operands
/// and rewrites the op with the result.
/// - Only single-result ops whose result has an FIRRTL base type are handled.
/// - Operands defined by ConstantOp/SpecialConstantOp contribute their value
///   attribute; other operands presumably contribute a null attribute (the
///   `attr` declaration is not visible).
/// - An Attribute result is materialized through the dialect; a Value result
///   is used directly.
/// - The new value is adjusted so its type equals the original result type,
///   which is asserted at the end.
294 Operation *op, PatternRewriter &rewriter,
295 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
297 if (op->getNumResults() != 1)
299 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
304 auto width = type.getBitWidthOrSentinel();
309 SmallVector<Attribute, 3> constOperands;
310 constOperands.reserve(op->getNumOperands());
311 for (
auto operand : op->getOperands()) {
313 if (
auto *defOp = operand.getDefiningOp())
314 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
315 [&](
auto op) { attr = op.getValueAttr(); });
316 constOperands.push_back(attr);
321 auto result = canonicalize(constOperands);
325 if (
auto cst = dyn_cast<Attribute>(result))
326 resultValue = op->getDialect()
327 ->materializeConstant(rewriter, cst, type, op->getLoc())
330 resultValue = cast<Value>(result);
// Pad when the new value's width differs from the original width (the first
// half of the condition is not visible).
334 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
335 resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);
// Restore the original signedness with an asSInt/asUInt cast if needed.
338 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
339 resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
340 else if (type_isa<UIntType>(type) &&
341 type_isa<SIntType>(resultValue.getType()))
342 resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);
344 assert(type == resultValue.getType() &&
"canonicalization changed type");
// Fragments of three width-limit helpers. For a positive `bitWidth` they
// return, in order:
//   1. the unsigned maximum,
//   2. the signed minimum,
//   3. the signed maximum.
// For any other width they return a default-constructed APInt.
352 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
358 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
364 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
// Constant-like ops: each has no operands (asserted) and folds to the
// attribute it stores.
371OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
372 assert(adaptor.getOperands().empty() &&
"constant has no operands");
373 return getValueAttr();
376OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
377 assert(adaptor.getOperands().empty() &&
"constant has no operands");
378 return getValueAttr();
// Aggregate constants fold to their fields attribute rather than a value attr.
381OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
382 assert(adaptor.getOperands().empty() &&
"constant has no operands");
383 return getFieldsAttr();
386OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
387 assert(adaptor.getOperands().empty() &&
"constant has no operands");
388 return getValueAttr();
391OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
392 assert(adaptor.getOperands().empty() &&
"constant has no operands");
393 return getValueAttr();
396OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
397 assert(adaptor.getOperands().empty() &&
"constant has no operands");
398 return getValueAttr();
401OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
402 assert(adaptor.getOperands().empty() &&
"constant has no operands");
403 return getValueAttr();
// AddPrimOp: constant operands fold through `a + b`. The call into the shared
// binary folder is not visible here.
410OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
413 [=](
const APSInt &a,
const APSInt &b) { return a + b; });
// Registers the add canonicalization patterns.
416void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
418 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
419 patterns::AddOfSelf, patterns::AddOfPad>(
context);
// SubPrimOp: constant operands fold through `a - b`.
422OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
425 [=](
const APSInt &a,
const APSInt &b) { return a - b; });
// Registers the subtract canonicalization patterns.
428void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
430 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
431 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
432 patterns::SubOfPadL, patterns::SubOfPadR>(
context);
// MulPrimOp: constant operands fold through `a * b`.
435OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
447 [=](
const APSInt &a,
const APSInt &b) { return a * b; });
// DivPrimOp:
// - x / x folds to 1 at the result width (some guard lines are not visible).
// - Division by a constant 1 is handled when the lhs type equals the result
//   type (the return is not visible).
// - The constant calculator returns 0 for some guarded case, presumably a
//   zero divisor (the guard is not visible).
450OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
457 if (getLhs() == getRhs()) {
458 auto width = getType().base().getWidthOrSentinel();
463 return getIntAttr(getType(), APInt(width, 1));
480 if (
auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
481 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
486 [=](
const APSInt &a,
const APSInt &b) -> APInt {
489 return APInt(a.getBitWidth(), 0);
// RemPrimOp: x % x is special-cased (result not visible). The constant
// calculator likewise returns 0 for a guarded case.
493OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
500 if (getLhs() == getRhs())
514 [=](
const APSInt &a,
const APSInt &b) -> APInt {
517 return APInt(a.getBitWidth(), 0);
// Dynamic shifts fold constant operands with APInt shifts.
// Both left-shift variants use shl.
521OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
524 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
527OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
530 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
// The right shift is logical (lshr) for an unsigned result or a zero-width
// lhs. The other branch (presumably ashr) is not visible here.
533OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
536 [=](
const APSInt &a,
const APSInt &b) -> APInt {
537 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
// AndPrimOp identities, checked for a constant on either side:
// - x & 0
// - x & all-ones, only when both operand types equal the result type
// - x & x, when the type matches
// Otherwise constant operands fold through `a & b`. The individual return
// lines are not visible in this fragment.
543OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
546 if (rhsCst->isZero())
550 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
551 getRhs().getType() == getType())
557 if (lhsCst->isZero())
561 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
562 getRhs().getType() == getType())
567 if (getLhs() == getRhs() && getRhs().getType() == getType())
572 [](
const APSInt &a,
const APSInt &b) -> APInt { return a & b; });
// Registers the and canonicalization patterns.
575void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
578 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
579 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
580 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(
context);
// OrPrimOp identities, checked for a constant on either side and each
// guarded by type equality with the result:
// - x | 0
// - x | all-ones
// - x | x
// Otherwise constant operands fold through `a | b`. Return lines are not
// visible in this fragment.
583OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
586 if (rhsCst->isZero() && getLhs().getType() == getType())
590 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
591 getLhs().getType() == getType())
597 if (lhsCst->isZero() && getRhs().getType() == getType())
601 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
602 getRhs().getType() == getType())
607 if (getLhs() == getRhs() && getRhs().getType() == getType())
612 [](
const APSInt &a,
const APSInt &b) -> APInt { return a | b; });
// Registers the or canonicalization patterns (the list continues past the
// visible lines).
615void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
617 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
618 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
// XorPrimOp:
// - x ^ 0 on either side is special-cased (conditions truncated here).
// - x ^ x folds to zero; the width is clamped to >= 0 so an unknown width
//   (negative sentinel) is not passed to APInt.
// - Otherwise constant operands fold through `a ^ b`.
622OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
625 if (rhsCst->isZero() &&
631 if (lhsCst->isZero() &&
636 if (getLhs() == getRhs())
639 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
643 [](
const APSInt &a,
const APSInt &b) -> APInt { return a ^ b; });
// Registers the xor canonicalization patterns.
646void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
648 results.insert<patterns::extendXor, patterns::moveConstXor,
649 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
// Registers the <= canonicalization pattern (constant on the lhs).
653void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
655 results.insert<patterns::LEQWithConstLHS>(
context);
// LEQPrimOp:
// - x <= x is special-cased.
// - For a known lhs width, a constant rhs is compared at a common width of at
//   least 1 bit (the comparison tails are not visible).
// - Otherwise constant operands fold to a 1-bit `a <= b`.
658OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
659 bool isUnsigned = getLhs().getType().base().isUnsigned();
662 if (getLhs() == getRhs())
666 if (
auto width = getLhs().getType().base().
getWidth()) {
668 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
669 commonWidth = std::max(commonWidth, 1);
680 if (isUnsigned && rhsCst->zext(commonWidth)
693 [=](
const APSInt &a,
const APSInt &b) -> APInt {
694 return APInt(1, a <= b);
// Registers the < canonicalization pattern (constant on the lhs).
698void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
700 results.insert<patterns::LTWithConstLHS>(
context);
// LTPrimOp:
// - x < x is special-cased.
// - For a known lhs width, a constant rhs is compared at a common width of at
//   least 1 bit (partially visible).
// - Otherwise constant operands fold to a 1-bit `a < b`.
703OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
704 IntType lhsType = getLhs().getType();
708 if (getLhs() == getRhs())
718 if (
auto width = lhsType.
getWidth()) {
720 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
721 commonWidth = std::max(commonWidth, 1);
732 if (isUnsigned && rhsCst->zext(commonWidth)
745 [=](
const APSInt &a,
const APSInt &b) -> APInt {
746 return APInt(1, a < b);
// Registers the >= canonicalization pattern (constant on the lhs).
750void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
752 results.insert<patterns::GEQWithConstLHS>(
context);
// GEQPrimOp:
// - x >= x is special-cased.
// - An unsigned x >= 0 is special-cased (presumably always true; the return
//   is not visible).
// - For a known lhs width, a constant rhs is compared at a common width.
// - Otherwise constant operands fold to a 1-bit `a >= b`.
755OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
756 IntType lhsType = getLhs().getType();
760 if (getLhs() == getRhs())
765 if (rhsCst->isZero() && isUnsigned)
770 if (
auto width = lhsType.
getWidth()) {
772 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
773 commonWidth = std::max(commonWidth, 1);
776 if (isUnsigned && rhsCst->zext(commonWidth)
797 [=](
const APSInt &a,
const APSInt &b) -> APInt {
798 return APInt(1, a >= b);
// Registers the > canonicalization pattern (constant on the lhs).
802void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
804 results.insert<patterns::GTWithConstLHS>(
context);
// GTPrimOp:
// - x > x is special-cased.
// - For a known lhs width, a constant rhs is compared at a common width of at
//   least 1 bit (partially visible).
// - Otherwise constant operands fold to a 1-bit `a > b`.
807OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
808 IntType lhsType = getLhs().getType();
812 if (getLhs() == getRhs())
816 if (
auto width = lhsType.
getWidth()) {
818 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
819 commonWidth = std::max(commonWidth, 1);
822 if (isUnsigned && rhsCst->zext(commonWidth)
843 [=](
const APSInt &a,
const APSInt &b) -> APInt {
844 return APInt(1, a > b);
// EQPrimOp fold:
// - x == x is special-cased.
// - x == all-ones is special-cased when both operand types equal the result
//   type (return not visible).
// - Otherwise constant operands fold to a 1-bit `a == b`.
848OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
850 if (getLhs() == getRhs())
856 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
857 getRhs().getType() == getType())
863 [=](
const APSInt &a,
const APSInt &b) -> APInt {
864 return APInt(1, a == b);
// EQPrimOp canonicalization against a constant rhs:
// - x == 0 with all types equal to the result      -> not(x)
// - x == 0 with lhs width > 1                      -> not(orr(x))
// - x == all-ones with width > 1 and equal types   -> andr(x)
868LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
870 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
872 auto width = op.getLhs().getType().getBitWidthOrSentinel();
875 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
876 op.getRhs().getType() == op.getType()) {
877 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
882 if (rhsCst->isZero() && width > 1) {
883 auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
884 return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
888 if (rhsCst->isAllOnes() && width > 1 &&
889 op.getLhs().getType() == op.getRhs().getType()) {
890 return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
// NEQPrimOp fold:
// - x != x is special-cased.
// - x != 0 is special-cased when both operand types equal the result type
//   (return not visible).
// - Otherwise constant operands fold to a 1-bit `a != b`.
898OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
900 if (getLhs() == getRhs())
906 if (rhsCst->isZero() && getLhs().getType() == getType() &&
907 getRhs().getType() == getType())
913 [=](
const APSInt &a,
const APSInt &b) -> APInt {
914 return APInt(1, a != b);
// NEQPrimOp canonicalization against a constant rhs:
// - x != all-ones with all types equal to the result  -> not(x)
// - x != 0 with lhs width > 1                         -> orr(x)
// - x != all-ones with width > 1 and equal types      -> not(andr(x))
918LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
920 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
922 auto width = op.getLhs().getType().getBitWidthOrSentinel();
925 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
926 op.getRhs().getType() == op.getType()) {
927 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
932 if (rhsCst->isZero() && width > 1) {
933 return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
938 if (rhsCst->isAllOnes() && width > 1 &&
939 op.getLhs().getType() == op.getRhs().getType()) {
941 AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
942 return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
// Folds for the arbitrary-precision Integer* ops. The add and mul bodies are
// not visible here.
950OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
956OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
// IntegerShrOp: constant operands fold to a signed IntegerAttr holding
// lhs.ashr(rhs). A zero shift amount is special-cased (return not visible).
962OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
966 return IntegerAttr::get(IntegerType::get(getContext(),
967 lhsCst->getBitWidth(),
968 IntegerType::Signed),
969 lhsCst->ashr(*rhsCst));
972 if (rhsCst->isZero())
// IntegerShlOp: constant operands fold to a signed IntegerAttr holding
// lhs.shl(rhs). A zero shift amount is special-cased (return not visible).
979OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
984 return IntegerAttr::get(IntegerType::get(getContext(),
985 lhsCst->getBitWidth(),
986 IntegerType::Signed),
987 lhsCst->shl(*rhsCst));
990 if (rhsCst->isZero())
// SizeOfIntrinsicOp::fold inspects the input type; the rest of the body is not
// visible. IsXIntrinsicOp::fold's body is not visible at all.
1001OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
1002 auto base = getInput().getType();
1009OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
// AsSIntPrimOp/AsUIntPrimOp folds: part of each body is guarded on the result
// having a known width (the guarded statements are not visible).
1016OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1024 if (getType().base().hasWidth())
// Registers the asSInt(asUInt(x)) pattern.
1031void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1033 results.insert<patterns::StoUtoS>(
context);
1036OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1044 if (getType().base().hasWidth())
// Registers the asUInt(asSInt(x)) pattern.
1051void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1053 results.insert<patterns::UtoStoU>(
context);
// Reset/clock casts:
// - An identity cast (input type equals result type) is special-cased
//   (return not visible).
// - A constant input folds to a BoolAttr of its boolean value.
1056OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1058 if (getInput().getType() == getType())
1063 return BoolAttr::get(getContext(), cst->getBoolValue());
1068OpFoldResult AsResetPrimOp::fold(FoldAdaptor adaptor) {
1070 return BoolAttr::get(getContext(), cst->getBoolValue());
1074OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1076 if (getInput().getType() == getType())
1081 return BoolAttr::get(getContext(), cst->getBoolValue());
// CvtPrimOp fold: constant handling is keyed on the result width (the call
// head is not visible).
1086OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1092 getType().base().getWidthOrSentinel()))
1098void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1100 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(
context);
// NegPrimOp: a constant at the result width folds to 0 - cst.
1103OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1110 getType().base().getWidthOrSentinel()))
1111 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
// NotPrimOp: constant handling at the result width (the return is not
// visible). The pattern list continues past the visible lines.
1116OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1121 getType().base().getWidthOrSentinel()))
1127void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1129 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1130 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
// Fragment of a base pattern for reductions over a CatPrimOp operand. It is
// registered with benefit 0 for the op named `opName`.
1137 : RewritePattern(opName, 0,
context) {}
// Subclass hook for one constant cat input. Returning true (presumably) means
// the hook rewrote `op` itself. Otherwise it may append values to `remaining`.
1143 ConstantOp constantOp,
1144 SmallVectorImpl<Value> &remaining)
const = 0;
// Looks through the cat feeding operand 0 and splits its inputs into
// constants (handed to handleConstant) and non-constants. Then:
// - no non-constants left -> the reduction becomes a ConstantOp;
// - exactly one left      -> operand 0 is pointed at it in place;
// - otherwise, if the cat has a single use and shrank, the cat is rebuilt
//   from the remaining operands.
1151 mlir::PatternRewriter &rewriter)
const override {
1153 auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
1157 SmallVector<Value> nonConstantOperands;
1160 for (
auto operand : catOp.getInputs()) {
1161 if (
auto constantOp = operand.getDefiningOp<ConstantOp>()) {
1163 if (
handleConstant(rewriter, op, constantOp, nonConstantOperands))
1167 nonConstantOperands.push_back(operand);
1172 if (nonConstantOperands.empty()) {
1173 replaceOpWithNewOpAndCopyName<ConstantOp>(
1174 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1180 if (nonConstantOperands.size() == 1) {
1181 rewriter.modifyOpInPlace(
1182 op, [&] { op->setOperand(0, nonConstantOperands.front()); });
1187 if (catOp->hasOneUse() &&
1188 nonConstantOperands.size() < catOp->getNumOperands()) {
1189 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
1190 nonConstantOperands);
// Subclass hook (presumably for and-reduction): a zero constant input makes
// the whole result a constant.
1203 SmallVectorImpl<Value> &remaining)
const override {
1204 if (value.getValue().isZero())
1207 replaceOpWithNewOpAndCopyName<ConstantOp>(
1208 rewriter, op, cast<IntType>(op->getResult(0).getType()),
// Subclass hook (presumably for or-reduction): an all-ones constant input
// makes the whole result a constant.
1221 SmallVectorImpl<Value> &remaining)
const override {
1222 if (value.getValue().isAllOnes())
1225 replaceOpWithNewOpAndCopyName<ConstantOp>(
1226 rewriter, op, cast<IntType>(op->getResult(0).getType()),
// Subclass hook (presumably for xor-reduction): zero constants are handled
// specially (branch body not visible); others are kept in `remaining`.
1239 SmallVectorImpl<Value> &remaining)
const override {
1240 if (value.getValue().isZero())
1242 remaining.push_back(value);
// AndRPrimOp fold:
// - A zero-width input is special-cased.
// - A constant input folds to a 1-bit "all ones?" result.
// - A UInt<1> input is special-cased (presumably folding to the input).
1248OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1252 if (getInput().getType().getBitWidthOrSentinel() == 0)
1257 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1261 if (
isUInt1(getInput().getType()))
1267void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1269 results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1270 patterns::AndRPadS, patterns::AndRCatAndR_left,
// OrRPrimOp fold: same shape as AndR; a constant folds to "any bit set?".
1274OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1278 if (getInput().getType().getBitWidthOrSentinel() == 0)
1283 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1287 if (
isUInt1(getInput().getType()))
1293void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1295 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1296 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right,
OrRCat>(
// XorRPrimOp fold: same shape; a constant folds to the parity of its set bits.
1300OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1304 if (getInput().getType().getBitWidthOrSentinel() == 0)
1309 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1312 if (
isUInt1(getInput().getType()))
1318void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1321 .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1322 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right,
XorRCat>(
// CatPrimOp fold:
// - A single input whose type equals the result is special-cased.
// - Zero-width inputs are ignored.
// - If exactly one non-zero-width input remains and its type matches, that
//   input is returned.
// - If every input is constant, the constants are concatenated in input order.
1330OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1331 auto inputs = getInputs();
1332 auto inputAdaptors = adaptor.getInputs();
1339 if (inputs.size() == 1 && inputs[0].getType() == getType())
1347 SmallVector<Value> nonZeroInputs;
1348 SmallVector<Attribute> nonZeroAttributes;
1349 bool allConstant =
true;
1350 for (
auto [input, attr] :
llvm::zip(inputs, inputAdaptors)) {
1351 auto inputType = type_cast<IntType>(input.getType());
1352 if (inputType.getBitWidthOrSentinel() != 0) {
1353 nonZeroInputs.push_back(input);
1355 allConstant =
false;
// Presumably an early exit once no simplification is possible.
1356 if (nonZeroInputs.size() > 1 && !allConstant)
1362 if (nonZeroInputs.empty())
1366 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1367 return nonZeroInputs[0];
1373 SmallVector<APInt> constants;
1374 for (
auto inputAdaptor : inputAdaptors) {
1376 constants.push_back(*cst);
1381 assert(!constants.empty());
// APInt::concat places its argument in the low bits, so constants[0] ends up
// most significant.
1383 APInt result = constants[0];
1384 for (
size_t i = 1; i < constants.size(); ++i)
1385 result = result.concat(constants[i]);
// Registers the constant-shift-amount pattern for dynamic left shifts.
1390void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1392 results.insert<patterns::DShlOfConstant>(
context);
// Registers the constant-shift-amount pattern for dynamic right shifts.
1395void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1397 results.insert<patterns::DShrOfConstant>(
context);
// Pattern that flattens nested CatPrimOps into a single cat.
1405 using OpRewritePattern::OpRewritePattern;
1408 matchAndRewrite(CatPrimOp cat,
1409 mlir::PatternRewriter &rewriter)
const override {
// Zero-width cats are rejected (the first half of the condition is not
// visible).
1411 cat.getType().getBitWidthOrSentinel() == 0)
// A cat whose only user is another cat is left for the outer cat, presumably
// so flattening happens once from the outermost cat.
1415 if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
1419 SmallVector<Value> operands;
1420 SmallVector<Value> worklist;
// Inputs are pushed in reverse so that popping visits them in original order.
1421 auto pushOperands = [&worklist](CatPrimOp op) {
1422 for (
auto operand :
llvm::reverse(op.getInputs()))
1423 worklist.push_back(operand);
// Tracks whether both signed and unsigned leaf operands are present.
1426 bool hasSigned =
false, hasUnsigned =
false;
1427 while (!worklist.empty()) {
1428 auto value = worklist.pop_back_val();
1429 auto catOp = value.getDefiningOp<CatPrimOp>();
1431 operands.push_back(value);
1432 (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) =
true;
1436 pushOperands(catOp);
1441 auto castToUIntIfSigned = [&](Value value) -> Value {
1442 if (type_isa<UIntType>(value.getType()))
1444 return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
1447 assert(operands.size() >= 1 &&
"zero width cast must be rejected");
// A single leaf replaces the cat directly, cast to UInt when it is signed.
1449 if (operands.size() == 1) {
1450 rewriter.replaceOp(cat, castToUIntIfSigned(operands[0]));
// Nothing was flattened when the leaf count equals the cat's operand count.
1454 if (operands.size() == cat->getNumOperands())
// Mixed signedness: every signed leaf is cast to UInt first.
1458 if (hasSigned && hasUnsigned)
1459 for (
auto &operand : operands)
1460 operand = castToUIntIfSigned(operand);
1462 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
// Pattern that merges runs of adjacent ConstantOp inputs of a cat into a
// single constant, built with APSInt::concat.
1471 using OpRewritePattern::OpRewritePattern;
1474 matchAndRewrite(CatPrimOp cat,
1475 mlir::PatternRewriter &rewriter)
const override {
1479 SmallVector<Value> operands;
1481 for (
size_t i = 0; i < cat->getNumOperands(); ++i) {
1482 auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
// Non-constant inputs are copied through unchanged.
1484 operands.push_back(cat.getInputs()[i]);
1487 APSInt value = cst.getValue();
// Extend the run with following constants (loop-exit lines not visible).
1489 for (; j < cat->getNumOperands(); ++j) {
1490 auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
1493 value = value.concat(nextCst.getValue());
// A run of one keeps the original constant; longer runs get a new ConstantOp.
1498 operands.push_back(cst);
1501 operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
// Nothing merged: the operand count is unchanged.
1507 if (operands.size() == cat->getNumOperands())
1510 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
// Registers the cat canonicalization patterns, including the locally defined
// FlattenCat and CatOfConstant.
1519void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1521 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1522 patterns::CatCast, FlattenCat, CatOfConstant>(
context);
// StringConcatOp fold:
// - A single input folds to that input.
// - When every input is a constant StringAttr, the op folds to their
//   concatenation.
1529OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
1531 if (getInputs().size() == 1)
1532 return getInputs()[0];
1535 if (!llvm::all_of(adaptor.getInputs(), [](Attribute operand) {
1536 return isa_and_nonnull<StringAttr>(operand);
1541 SmallString<64> result;
1542 for (
auto operand : adaptor.getInputs())
1543 result += cast<StringAttr>(operand).getValue();
1545 return StringAttr::get(getContext(), result);
// Pattern that inlines the inputs of nested StringConcatOps into the outer
// concat. It only applies to nested concats with a single use, so no other
// user loses its value.
1553 using OpRewritePattern::OpRewritePattern;
1556 matchAndRewrite(StringConcatOp concat,
1557 mlir::PatternRewriter &rewriter)
const override {
1561 bool hasNestedConcat = llvm::any_of(concat.getInputs(), [](Value operand) {
1562 auto nestedConcat = operand.getDefiningOp<StringConcatOp>();
1563 return nestedConcat && operand.hasOneUse();
1566 if (!hasNestedConcat)
1570 SmallVector<Value> flatOperands;
1571 for (
auto input : concat.getInputs()) {
1572 if (
auto nestedConcat = input.getDefiningOp<StringConcatOp>();
1573 nestedConcat && input.hasOneUse())
1574 llvm::append_range(flatOperands, nestedConcat.getInputs());
1576 flatOperands.push_back(input);
1579 rewriter.modifyOpInPlace(concat,
1580 [&]() { concat->setOperands(flatOperands); });
/// Pattern that merges runs of adjacent StringConstantOp inputs of a concat
/// into one literal.
1587class MergeAdjacentStringConstants
1590 using OpRewritePattern::OpRewritePattern;
1593 matchAndRewrite(StringConcatOp concat,
1594 mlir::PatternRewriter &rewriter)
const override {
1596 SmallVector<Value> newOperands;
1597 SmallString<64> accumulatedLit;
1598 SmallVector<StringConstantOp> accumulatedOps;
1599 bool changed =
false;
// Emits the pending run of literals. A run of one keeps the original op;
// longer runs become one new StringConstantOp. The buffers are then reset.
1601 auto flushLiterals = [&]() {
1602 if (accumulatedOps.empty())
1606 if (accumulatedOps.size() == 1) {
1607 newOperands.push_back(accumulatedOps[0]);
1610 auto newLit = rewriter.createOrFold<StringConstantOp>(
1611 concat.getLoc(), StringAttr::get(getContext(), accumulatedLit));
1612 newOperands.push_back(newLit);
1615 accumulatedLit.clear();
1616 accumulatedOps.clear();
1619 for (
auto operand : concat.getInputs()) {
1620 if (
auto litOp = operand.getDefiningOp<StringConstantOp>()) {
// Empty literals are handled specially (body not visible; presumably dropped).
1622 if (litOp.getValue().empty()) {
1626 accumulatedLit += litOp.getValue();
1627 accumulatedOps.push_back(litOp);
1630 newOperands.push_back(operand);
// If every input vanished, the concat becomes an empty string constant.
1641 if (newOperands.empty())
1642 return rewriter.replaceOpWithNewOp<StringConstantOp>(
1643 concat, StringAttr::get(getContext(),
"")),
1647 rewriter.modifyOpInPlace(concat,
1648 [&]() { concat->setOperands(newOperands); });
// Registers the string-concat flattening and literal-merging patterns.
1655void StringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
1657 results.insert<FlattenStringConcat, MergeAdjacentStringConstants>(
context);
// PropEqOp fold: when both operands are constant, it folds to a BoolAttr
// recording whether the two attributes are identical. Otherwise it bails out
// (return not visible).
1664OpFoldResult PropEqOp::fold(FoldAdaptor adaptor) {
1665 auto lhsAttr = adaptor.getLhs();
1666 auto rhsAttr = adaptor.getRhs();
1667 if (!lhsAttr || !rhsAttr)
1670 return BoolAttr::get(getContext(), lhsAttr == rhsAttr);
// Fragment of a helper mapping an attribute to std::optional<bool>: a BoolAttr
// yields its value, anything else (including null) yields nullopt.
1679 if (
auto boolAttr = dyn_cast_or_null<BoolAttr>(attr))
1680 return boolAttr.getValue();
1681 return std::nullopt;
// BoolAndOp:
// - Both sides known -> lhs && rhs.
// - Either side known false -> false.
1684OpFoldResult BoolAndOp::fold(FoldAdaptor adaptor) {
1688 return BoolAttr::get(getContext(), *lhs && *rhs);
1690 if ((lhs && !*lhs) || (rhs && !*rhs))
1691 return BoolAttr::get(getContext(),
false);
// BoolOrOp:
// - Both sides known -> lhs || rhs.
// - Either side known true -> true.
1700OpFoldResult BoolOrOp::fold(FoldAdaptor adaptor) {
1704 return BoolAttr::get(getContext(), *lhs || *rhs);
1706 if ((lhs && *lhs) || (rhs && *rhs))
1707 return BoolAttr::get(getContext(),
true);
// BoolXorOp: both sides known -> lhs ^ rhs.
1716OpFoldResult BoolXorOp::fold(FoldAdaptor adaptor) {
1720 return BoolAttr::get(getContext(), *lhs ^ *rhs);
// BitCastOp fold (`op` is declared on a line not visible here):
// - An identity cast folds to its input.
// - A round trip through another bitcast back to the original type folds to
//   that original value.
1729OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1732 if (op.getType() == op.getInput().getType())
1733 return op.getInput();
1737 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1738 if (op.getType() == in.getInput().getType())
1739 return in.getInput();
// BitsPrimOp fold:
// - Same input/result type with a known width is special-cased (presumably
//   the identity).
// - A constant input folds to bits [hi:lo], inclusive.
1744OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1745 IntType inputType = getInput().getType();
1746 IntType resultType = getType();
1748 if (inputType == getType() && resultType.
hasWidth())
1755 cst->extractBits(getHi() - getLo() + 1, getLo()));
// Pattern for bits(cat(...)). It walks the cat inputs in reverse while
// tracking the bit offset `bitPos`. When the requested range lies entirely
// inside one input, it extracts the bits directly from that input.
1761 using OpRewritePattern::OpRewritePattern;
1765 mlir::PatternRewriter &rewriter)
const override {
1766 auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
1769 int32_t bitPos = bits.getLo();
1770 auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
// Unknown widths (negative sentinel) cannot be reasoned about.
1771 if (resultWidth < 0)
1773 for (
auto operand : llvm::reverse(cat.getInputs())) {
1775 type_cast<IntType>(operand.getType()).getWidthOrSentinel();
1776 if (operandWidth < 0)
1778 if (bitPos < operandWidth) {
1779 if (bitPos + resultWidth <= operandWidth) {
1780 auto newBits = rewriter.createOrFold<BitsPrimOp>(
1781 bits.getLoc(), operand, bitPos + resultWidth - 1, bitPos);
// Move past this input and shift the offset into the next one.
1787 bitPos -= operandWidth;
// Registers the bits canonicalization patterns (the list continues past the
// visible lines).
1793void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1796 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
// Fragment: replaces `op` with bits [hiBit:loBit] of `value`.
// - The extraction is emitted only when the widths differ.
// - The result is then cast (asSInt/asUInt) so its signedness matches the
//   op's result type.
1804 unsigned loBit, PatternRewriter &rewriter) {
1805 auto resType = type_cast<IntType>(op->getResult(0).getType());
1806 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1807 value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);
1809 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1810 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1811 }
else if (resType.isUnsigned() &&
1812 !type_cast<IntType>(value.getType()).isUnsigned()) {
1813 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1815 rewriter.replaceOp(op, value);
/// Shared fold for 2-way mux-like ops.
/// - Zero-width type -> zero-width constant.
/// - mux(c, x, x) -> x, when the types match.
/// - Unknown width -> no further folding (the return is not visible).
/// - Constant selector -> low when zero, high otherwise (types must match).
/// - Equal constant arms, and the high == 1 / low == 0 case whose type equals
///   the selector's, are special-cased (returns not visible).
1818template <
typename OpTy>
1819static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
1821 if (op.getType().getBitWidthOrSentinel() == 0)
1823 APInt(0, 0, op.getType().isSignedInteger()));
1826 if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
1827 return op.getHigh();
1832 if (op.getType().getBitWidthOrSentinel() < 0)
1837 if (cond->isZero() && op.getLow().getType() == op.getType())
1839 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1840 return op.getHigh();
1844 if (
auto lowCst =
getConstant(adaptor.getLow())) {
1846 if (
auto highCst =
getConstant(adaptor.getHigh())) {
// Both arms are the same constant (same width and value).
1848 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1849 *highCst == *lowCst)
// mux(c, 1, 0) presumably reduces to the selector itself.
1852 if (highCst->isOne() && lowCst->isZero() &&
1853 op.getType() == op.getSel().getType())
// Both 2-way mux ops delegate to the shared foldMux helper.
1866OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1867 return foldMux(*
this, adaptor);
1870OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1871 return foldMux(*
this, adaptor);
1874OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return {}; }
// Pattern that pads both mux arms to the mux's result width.
// - An arm with an unknown width, or one already at the result width, is left
//   as is.
// - The pattern bails out when neither arm changes; otherwise the mux is
//   rebuilt with the padded arms.
1883 using OpRewritePattern::OpRewritePattern;
1886 matchAndRewrite(MuxPrimOp mux,
1887 mlir::PatternRewriter &rewriter)
const override {
1888 auto width = mux.getType().getBitWidthOrSentinel();
1892 auto pad = [&](Value input) -> Value {
1894 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1895 if (inputWidth < 0 || width == inputWidth)
1897 return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
1902 auto newHigh = pad(mux.getHigh());
1903 auto newLow = pad(mux.getLow());
1904 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1907 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1908 rewriter, mux, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
// Pattern that simplifies nested muxes sharing the outer mux's selector.
// Inside the high arm the selector is known true; inside the low arm it is
// known false.
1918 using OpRewritePattern::OpRewritePattern;
// Maximum nesting depth explored by tryCondTrue/tryCondFalse.
1920 static const int depthLimit = 5;
// Sets `mux`'s arms to (high, low). This happens in place when allowed;
// otherwise a new mux is created right after `mux`.
1922 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1923 mlir::PatternRewriter &rewriter,
1924 bool updateInPlace)
const {
1925 if (updateInPlace) {
1926 rewriter.modifyOpInPlace(mux, [&] {
1927 mux.setOperand(1, high);
1928 mux.setOperand(2, low);
1932 rewriter.setInsertionPointAfter(mux);
1933 return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
1934 ValueRange{mux.getSel(), high, low})
// Simplifies `op` assuming `cond` is true.
// - A mux on `cond` collapses to its high arm.
// - Otherwise both arms are searched recursively, up to depthLimit.
// - In-place updates are only allowed while every mux on the path has a
//   single use.
1939 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1940 bool updateInPlace,
int limit)
const {
1941 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1944 if (mux.getSel() == cond)
1945 return mux.getHigh();
1946 if (limit > depthLimit)
1948 updateInPlace &= mux->hasOneUse();
1950 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1952 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1955 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1956 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
// Mirror of tryCondTrue assuming `cond` is false: a mux on `cond` collapses to
// its low arm.
1961 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1962 bool updateInPlace,
int limit)
const {
1963 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1966 if (mux.getSel() == cond)
1967 return mux.getLow();
1968 if (limit > depthLimit)
1970 updateInPlace &= mux->hasOneUse();
1972 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1974 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1976 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1978 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
// Entry point. The high arm is simplified under "sel is true" and the low arm
// under "sel is false".
1984 matchAndRewrite(MuxPrimOp mux,
1985 mlir::PatternRewriter &rewriter)
const override {
1986 auto width = mux.getType().getBitWidthOrSentinel();
1990 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter,
true, 0)) {
1991 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1995 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter,
true, 0)) {
1996 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
2005void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
2008 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
2009 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
2010 patterns::MuxSameTrue, patterns::MuxSameFalse,
2011 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
2015void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
2016 RewritePatternSet &results, MLIRContext *
context) {
2017 results.add<patterns::Mux2PadSel>(
context);
2020void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
2021 RewritePatternSet &results, MLIRContext *
context) {
2022 results.add<patterns::Mux4PadSel>(
context);
2025OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
2026 auto input = this->getInput();
2029 if (input.getType() == getType())
2033 auto inputType = input.getType().base();
2040 auto destWidth = getType().base().getWidthOrSentinel();
2041 if (destWidth == -1)
2044 if (inputType.
isSigned() && cst->getBitWidth())
2045 return getIntAttr(getType(), cst->sext(destWidth));
2046 return getIntAttr(getType(), cst->zext(destWidth));
2052OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
2053 auto input = this->getInput();
2054 IntType inputType = input.getType();
2055 int shiftAmount = getAmount();
2058 if (shiftAmount == 0)
2064 if (inputWidth != -1) {
2065 auto resultWidth = inputWidth + shiftAmount;
2066 shiftAmount = std::min(shiftAmount, resultWidth);
2067 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
2073OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
2074 auto input = this->getInput();
2075 IntType inputType = input.getType();
2076 int shiftAmount = getAmount();
2082 if (shiftAmount == 0 && inputWidth > 0)
2085 if (inputWidth == -1)
2087 if (inputWidth == 0)
2092 if (shiftAmount >= inputWidth && inputType.
isUnsigned())
2093 return getIntAttr(getType(), APInt(0, 0,
false));
2099 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
2101 value = cst->lshr(std::min(shiftAmount, inputWidth));
2102 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
2103 return getIntAttr(getType(), value.trunc(resultWidth));
2108LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
2109 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2110 if (inputWidth <= 0)
2114 unsigned shiftAmount = op.getAmount();
2115 if (
int(shiftAmount) >= inputWidth) {
2117 if (op.getType().base().isUnsigned())
2123 shiftAmount = inputWidth - 1;
2126 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
2130LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
2131 PatternRewriter &rewriter) {
2132 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2133 if (inputWidth <= 0)
2137 unsigned keepAmount = op.getAmount();
2139 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
2144OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
2148 getInput().getType().base().getWidthOrSentinel() - getAmount();
2149 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
2155OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
2159 cst->trunc(getType().base().getWidthOrSentinel()));
2163LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
2164 PatternRewriter &rewriter) {
2165 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2166 if (inputWidth <= 0)
2170 unsigned dropAmount = op.getAmount();
2171 if (dropAmount !=
unsigned(inputWidth))
2177void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
2179 results.add<patterns::SubaccessOfConstant>(
context);
2182OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
2184 if (adaptor.getInputs().size() == 1)
2185 return getOperand(1);
2187 if (
auto constIndex =
getConstant(adaptor.getIndex())) {
2188 auto index = constIndex->getZExtValue();
2189 if (index < getInputs().size())
2190 return getInputs()[getInputs().size() - 1 - index];
2196LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
2197 PatternRewriter &rewriter) {
2201 if (llvm::all_of(op.getInputs().drop_front(), [&](
auto input) {
2202 return input == op.getInputs().front();
2210 auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
2211 uint64_t inputSize = op.getInputs().size();
2212 if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
2213 rewriter.modifyOpInPlace(op, [&]() {
2214 op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
2221 if (
auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
2222 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](
auto e) {
2223 auto subindex = e.value().template getDefiningOp<SubindexOp>();
2224 return subindex && lastSubindex.getInput() == subindex.getInput() &&
2225 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
2227 replaceOpWithNewOpAndCopyName<SubaccessOp>(
2228 rewriter, op, lastSubindex.getInput(), op.getIndex());
2234 if (op.getInputs().size() != 2)
2238 auto uintType = op.getIndex().getType();
2239 if (uintType.getBitWidthOrSentinel() != 1)
2243 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
2244 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
2263 MatchingConnectOp connect;
2264 for (Operation *user : value.getUsers()) {
2266 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
2269 if (
auto aConnect = dyn_cast<FConnectLike>(user))
2270 if (aConnect.getDest() == value) {
2271 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
2274 if (!matchingConnect || (connect && connect != matchingConnect) ||
2275 matchingConnect->getBlock() != value.getParentBlock())
2277 connect = matchingConnect;
2285 PatternRewriter &rewriter) {
2288 Operation *connectedDecl = op.getDest().getDefiningOp();
2293 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
2297 cast<Forceable>(connectedDecl).isForceable())
2305 if (connectedDecl->hasOneUse())
2309 auto *declBlock = connectedDecl->getBlock();
2310 auto *srcValueOp = op.getSrc().getDefiningOp();
2313 if (!isa<WireOp>(connectedDecl))
2319 if (!isa<ConstantOp>(srcValueOp))
2321 if (srcValueOp->getBlock() != declBlock)
2327 auto replacement = op.getSrc();
2330 if (srcValueOp && srcValueOp != &declBlock->front())
2331 srcValueOp->moveBefore(&declBlock->front());
2338 rewriter.eraseOp(op);
2342void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2344 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
2348LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2349 PatternRewriter &rewriter) {
2366 for (
auto *user : value.getUsers()) {
2367 auto attach = dyn_cast<AttachOp>(user);
2368 if (!attach || attach == dominatedAttach)
2370 if (attach->isBeforeInBlock(dominatedAttach))
2376LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
2378 if (op.getNumOperands() <= 1) {
2379 rewriter.eraseOp(op);
2383 for (
auto operand : op.getOperands()) {
2390 SmallVector<Value> newOperands(op.getOperands());
2391 for (
auto newOperand : attach.getOperands())
2392 if (newOperand != operand)
2393 newOperands.push_back(newOperand);
2394 AttachOp::create(rewriter, op->getLoc(), newOperands);
2395 rewriter.eraseOp(attach);
2396 rewriter.eraseOp(op);
2404 if (
auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
2405 if (!
hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
2406 !wire.isForceable()) {
2407 SmallVector<Value> newOperands;
2408 for (
auto newOperand : op.getOperands())
2409 if (newOperand != operand)
2410 newOperands.push_back(newOperand);
2412 AttachOp::create(rewriter, op->getLoc(), newOperands);
2413 rewriter.eraseOp(op);
2414 rewriter.eraseOp(wire);
2425 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
2426 rewriter.inlineBlockBefore(®ion.front(), op, {});
2429LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
2430 if (
auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
2431 if (constant.getValue().isAllOnes())
2433 else if (op.hasElseRegion() && !op.getElseRegion().empty())
2436 rewriter.eraseOp(op);
2442 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
2443 op.getElseBlock().empty()) {
2444 rewriter.eraseBlock(&op.getElseBlock());
2451 if (!op.getThenBlock().empty())
2455 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
2456 rewriter.eraseOp(op);
2466 using OpRewritePattern::OpRewritePattern;
2467 LogicalResult matchAndRewrite(NodeOp node,
2468 PatternRewriter &rewriter)
const override {
2469 auto name = node.getNameAttr();
2470 if (!node.hasDroppableName() || node.getInnerSym() ||
2473 auto *newOp = node.getInput().getDefiningOp();
2476 rewriter.replaceOp(node, node.getInput());
2483 using OpRewritePattern::OpRewritePattern;
2484 LogicalResult matchAndRewrite(NodeOp node,
2485 PatternRewriter &rewriter)
const override {
2487 node.use_empty() || node.isForceable())
2489 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
2496template <
typename OpTy>
2498 PatternRewriter &rewriter) {
2499 if (!op.isForceable() || !op.getDataRef().use_empty())
2507LogicalResult NodeOp::fold(FoldAdaptor adaptor,
2508 SmallVectorImpl<OpFoldResult> &results) {
2517 if (!adaptor.getInput())
2520 results.push_back(adaptor.getInput());
2524void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2526 results.insert<FoldNodeName>(
context);
2527 results.add(demoteForceableIfUnused<NodeOp>);
2533struct AggOneShot :
public mlir::RewritePattern {
2534 AggOneShot(StringRef name, uint32_t weight, MLIRContext *
context)
2535 : RewritePattern(name, 0,
context) {}
2537 SmallVector<Value> getCompleteWrite(Operation *lhs)
const {
2538 auto lhsTy = lhs->getResult(0).getType();
2539 if (!type_isa<BundleType, FVectorType>(lhsTy))
2542 DenseMap<uint32_t, Value> fields;
2543 for (Operation *user : lhs->getResult(0).getUsers()) {
2544 if (user->getParentOp() != lhs->getParentOp())
2546 if (
auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2547 if (aConnect.getDest() == lhs->getResult(0))
2549 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2550 for (Operation *subuser : subField.getResult().getUsers()) {
2551 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2552 if (aConnect.getDest() == subField) {
2553 if (subuser->getParentOp() != lhs->getParentOp())
2555 if (fields.count(subField.getFieldIndex()))
2557 fields[subField.getFieldIndex()] = aConnect.getSrc();
2563 }
else if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2564 for (Operation *subuser : subIndex.getResult().getUsers()) {
2565 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2566 if (aConnect.getDest() == subIndex) {
2567 if (subuser->getParentOp() != lhs->getParentOp())
2569 if (fields.count(subIndex.getIndex()))
2571 fields[subIndex.getIndex()] = aConnect.getSrc();
2582 SmallVector<Value> values;
2583 uint32_t total = type_isa<BundleType>(lhsTy)
2584 ? type_cast<BundleType>(lhsTy).getNumElements()
2585 : type_cast<FVectorType>(lhsTy).getNumElements();
2586 for (uint32_t i = 0; i < total; ++i) {
2587 if (!fields.count(i))
2589 values.push_back(fields[i]);
2594 LogicalResult matchAndRewrite(Operation *op,
2595 PatternRewriter &rewriter)
const override {
2596 auto values = getCompleteWrite(op);
2599 rewriter.setInsertionPointToEnd(op->getBlock());
2600 auto dest = op->getResult(0);
2601 auto destType = dest.getType();
2604 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2607 Value newVal = type_isa<BundleType>(destType)
2608 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2610 : rewriter.createOrFold<VectorCreateOp>(
2611 op->
getLoc(), destType, values);
2612 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2613 for (Operation *user : dest.getUsers()) {
2614 if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2615 for (Operation *subuser :
2616 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2617 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2618 if (aConnect.getDest() == subIndex)
2619 rewriter.eraseOp(aConnect);
2620 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2621 for (Operation *subuser :
2622 llvm::make_early_inc_range(subField.getResult().getUsers()))
2623 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2624 if (aConnect.getDest() == subField)
2625 rewriter.eraseOp(aConnect);
2632struct WireAggOneShot :
public AggOneShot {
2633 WireAggOneShot(MLIRContext *
context)
2634 : AggOneShot(WireOp::getOperationName(), 0,
context) {}
2636struct SubindexAggOneShot :
public AggOneShot {
2637 SubindexAggOneShot(MLIRContext *
context)
2638 : AggOneShot(SubindexOp::getOperationName(), 0,
context) {}
2640struct SubfieldAggOneShot :
public AggOneShot {
2641 SubfieldAggOneShot(MLIRContext *
context)
2642 : AggOneShot(SubfieldOp::getOperationName(), 0,
context) {}
2646void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2648 results.insert<WireAggOneShot>(
context);
2649 results.add(demoteForceableIfUnused<WireOp>);
2652void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2654 results.insert<SubindexAggOneShot>(
context);
2657OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2658 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2661 return attr[getIndex()];
2664OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2665 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2668 auto index = getFieldIndex();
2672void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2674 results.insert<SubfieldAggOneShot>(
context);
2678 ArrayRef<Attribute> operands) {
2679 for (
auto operand : operands)
2682 return ArrayAttr::get(
context, operands);
2685OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2688 if (getNumOperands() > 0)
2689 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2690 if (first.getFieldIndex() == 0 &&
2691 first.getInput().getType() == getType() &&
2693 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2695 elem.value().
template getDefiningOp<SubfieldOp>();
2696 return subindex && subindex.getInput() == first.getInput() &&
2697 subindex.getFieldIndex() == elem.index();
2699 return first.getInput();
2704OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2707 if (getNumOperands() > 0)
2708 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2709 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2711 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2713 elem.value().
template getDefiningOp<SubindexOp>();
2714 return subindex && subindex.getInput() == first.getInput() &&
2715 subindex.getIndex() == elem.index();
2717 return first.getInput();
2722OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2723 if (getOperand().getType() == getType())
2724 return getOperand();
2732 using OpRewritePattern::OpRewritePattern;
2733 LogicalResult matchAndRewrite(RegResetOp reg,
2734 PatternRewriter &rewriter)
const override {
2736 dyn_cast_or_null<ConstantOp>(
reg.getResetValue().getDefiningOp());
2745 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2748 auto *high = mux.getHigh().getDefiningOp();
2749 auto *low = mux.getLow().getDefiningOp();
2750 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2752 if (constOp && low != reg)
2754 if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
2755 constOp = dyn_cast<ConstantOp>(low);
2757 if (!constOp || constOp.getType() != reset.getType() ||
2758 constOp.getValue() != reset.getValue())
2762 auto regTy =
reg.getResult().getType();
2763 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2764 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2765 regTy.getBitWidthOrSentinel() < 0)
2771 if (constOp != &con->getBlock()->front())
2772 constOp->moveBefore(&con->getBlock()->front());
2777 rewriter.eraseOp(con);
2784 if (
auto c = v.getDefiningOp<ConstantOp>())
2785 return c.getValue().isOne();
2786 if (
auto sc = v.getDefiningOp<SpecialConstantOp>())
2787 return sc.getValue();
2796 auto resetValue = reg.getResetValue();
2797 if (reg.getType(0) != resetValue.getType())
2801 (void)
dropWrite(rewriter, reg->getResult(0), {});
2802 replaceOpWithNewOpAndCopyName<NodeOp>(
2803 rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
2804 reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
2808void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2810 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(
context);
2812 results.add(demoteForceableIfUnused<RegResetOp>);
2817 auto portTy = type_cast<BundleType>(port.getType());
2818 auto fieldIndex = portTy.getElementIndex(name);
2819 assert(fieldIndex &&
"missing field on memory port");
2822 for (
auto *op : port.getUsers()) {
2823 auto portAccess = cast<SubfieldOp>(op);
2824 if (fieldIndex != portAccess.getFieldIndex())
2829 value = conn.getSrc();
2839 auto portConst = value.getDefiningOp<ConstantOp>();
2842 return portConst.getValue().isZero();
2847 auto portTy = type_cast<BundleType>(port.getType());
2848 auto fieldIndex = portTy.getElementIndex(
data);
2849 assert(fieldIndex &&
"missing enable flag on memory port");
2851 for (
auto *op : port.getUsers()) {
2852 auto portAccess = cast<SubfieldOp>(op);
2853 if (fieldIndex != portAccess.getFieldIndex())
2855 if (!portAccess.use_empty())
2864 StringRef name, Value value) {
2865 auto portTy = type_cast<BundleType>(port.getType());
2866 auto fieldIndex = portTy.getElementIndex(name);
2867 assert(fieldIndex &&
"missing field on memory port");
2869 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2870 auto portAccess = cast<SubfieldOp>(op);
2871 if (fieldIndex != portAccess.getFieldIndex())
2873 rewriter.replaceAllUsesWith(portAccess, value);
2874 rewriter.eraseOp(portAccess);
2879static void erasePort(PatternRewriter &rewriter, Value port) {
2882 auto getClock = [&] {
2884 clock = SpecialConstantOp::create(rewriter, port.getLoc(),
2885 ClockType::get(rewriter.getContext()),
2894 for (
auto *op : port.getUsers()) {
2895 auto subfield = dyn_cast<SubfieldOp>(op);
2897 auto ty = port.getType();
2898 auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
2899 rewriter.replaceAllUsesWith(port, reg.getResult());
2908 for (
auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2909 auto access = cast<SubfieldOp>(accessOp);
2910 for (
auto *user : llvm::make_early_inc_range(access->getUsers())) {
2911 auto connect = dyn_cast<FConnectLike>(user);
2912 if (connect && connect.getDest() == access) {
2913 rewriter.eraseOp(user);
2917 if (access.use_empty()) {
2918 rewriter.eraseOp(access);
2924 auto ty = access.getType();
2925 auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
2926 rewriter.replaceOp(access, reg.getResult());
2928 assert(port.use_empty() &&
"port should have no remaining uses");
2934 using OpRewritePattern::OpRewritePattern;
2935 LogicalResult matchAndRewrite(MemOp mem,
2936 PatternRewriter &rewriter)
const override {
2940 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2941 mem.getDataType().getBitWidthOrSentinel() != 0)
2945 for (
auto port : mem.getResults())
2946 for (auto *user : port.getUsers())
2947 if (!isa<SubfieldOp>(user))
2952 for (
auto port : mem.getResults()) {
2953 for (
auto *user :
llvm::make_early_inc_range(port.getUsers())) {
2954 SubfieldOp sfop = cast<SubfieldOp>(user);
2955 StringRef fieldName = sfop.getFieldName();
2956 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2957 rewriter, sfop, sfop.getResult().getType())
2959 if (fieldName.ends_with(
"data")) {
2961 auto zero = firrtl::ConstantOp::create(
2962 rewriter, wire.getLoc(),
2963 firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
2964 MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
2968 rewriter.eraseOp(mem);
2975 using OpRewritePattern::OpRewritePattern;
2976 LogicalResult matchAndRewrite(MemOp mem,
2977 PatternRewriter &rewriter)
const override {
2980 bool isRead =
false, isWritten =
false;
2981 for (
unsigned i = 0; i < mem.getNumResults(); ++i) {
2982 switch (mem.getPortKind(i)) {
2983 case MemOp::PortKind::Read:
2988 case MemOp::PortKind::Write:
2993 case MemOp::PortKind::Debug:
2994 case MemOp::PortKind::ReadWrite:
2997 llvm_unreachable(
"unknown port kind");
2999 assert((!isWritten || !isRead) &&
"memory is in use");
3004 if (isRead && mem.getInit())
3007 for (
auto port : mem.getResults())
3010 rewriter.eraseOp(mem);
3017 using OpRewritePattern::OpRewritePattern;
3018 LogicalResult matchAndRewrite(MemOp mem,
3019 PatternRewriter &rewriter)
const override {
3023 llvm::SmallBitVector deadPorts(mem.getNumResults());
3024 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3026 if (!mem.getPortAnnotation(i).empty())
3030 auto kind = mem.getPortKind(i);
3031 if (kind == MemOp::PortKind::Debug)
3040 if (kind == MemOp::PortKind::Read &&
isPortUnused(port,
"data")) {
3045 if (deadPorts.none())
3049 SmallVector<Type> resultTypes;
3050 SmallVector<StringRef> portNames;
3051 SmallVector<Attribute> portAnnotations;
3052 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3055 resultTypes.push_back(port.getType());
3056 portNames.push_back(mem.getPortName(i));
3057 portAnnotations.push_back(mem.getPortAnnotation(i));
3061 if (!resultTypes.empty())
3062 newOp = MemOp::create(
3063 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
3064 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
3065 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
3066 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
3067 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3070 unsigned nextPort = 0;
3071 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3075 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
3078 rewriter.eraseOp(mem);
3085 using OpRewritePattern::OpRewritePattern;
3086 LogicalResult matchAndRewrite(MemOp mem,
3087 PatternRewriter &rewriter)
const override {
3092 llvm::SmallBitVector deadReads(mem.getNumResults());
3093 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3094 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
3096 if (!mem.getPortAnnotation(i).empty())
3103 if (deadReads.none())
3106 SmallVector<Type> resultTypes;
3107 SmallVector<StringRef> portNames;
3108 SmallVector<Attribute> portAnnotations;
3109 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3111 resultTypes.push_back(
3112 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
3113 MemOp::PortKind::Write, mem.getMaskBits()));
3115 resultTypes.push_back(port.getType());
3117 portNames.push_back(mem.getPortName(i));
3118 portAnnotations.push_back(mem.getPortAnnotation(i));
3121 auto newOp = MemOp::create(
3122 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
3123 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
3124 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
3125 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
3126 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3128 for (
unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
3129 auto result = mem.getResult(i);
3130 auto newResult = newOp.getResult(i);
3132 auto resultPortTy = type_cast<BundleType>(result.getType());
3136 auto replace = [&](StringRef toName, StringRef fromName) {
3137 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
3138 assert(fromFieldIndex &&
"missing enable flag on memory port");
3140 auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
3142 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
3143 auto fromField = cast<SubfieldOp>(op);
3144 if (fromFieldIndex != fromField.getFieldIndex())
3146 rewriter.replaceOp(fromField, toField.getResult());
3150 replace(
"addr",
"addr");
3151 replace(
"en",
"en");
3152 replace(
"clk",
"clk");
3153 replace(
"data",
"wdata");
3154 replace(
"mask",
"wmask");
3157 auto wmodeFieldIndex = resultPortTy.getElementIndex(
"wmode");
3158 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
3159 auto wmodeField = cast<SubfieldOp>(op);
3160 if (wmodeFieldIndex != wmodeField.getFieldIndex())
3162 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
3165 rewriter.replaceAllUsesWith(result, newResult);
3168 rewriter.eraseOp(mem);
3175 using OpRewritePattern::OpRewritePattern;
3177 LogicalResult matchAndRewrite(MemOp mem,
3178 PatternRewriter &rewriter)
const override {
3183 const auto &summary = mem.getSummary();
3184 if (summary.isMasked || summary.isSeqMem())
3187 auto type = type_dyn_cast<IntType>(mem.getDataType());
3190 auto width = type.getBitWidthOrSentinel();
3194 llvm::SmallBitVector usedBits(width);
3195 DenseMap<unsigned, unsigned> mapping;
3200 SmallVector<BitsPrimOp> readOps;
3201 auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
3202 auto portTy = type_cast<BundleType>(port.getType());
3203 auto fieldIndex = portTy.getElementIndex(field);
3204 assert(fieldIndex &&
"missing data port");
3206 for (
auto *op : port.getUsers()) {
3207 auto portAccess = cast<SubfieldOp>(op);
3208 if (fieldIndex != portAccess.getFieldIndex())
3211 for (
auto *user : op->getUsers()) {
3212 auto bits = dyn_cast<BitsPrimOp>(user);
3216 usedBits.set(bits.getLo(), bits.getHi() + 1);
3220 mapping[bits.getLo()] = 0;
3221 readOps.push_back(bits);
3231 SmallVector<MatchingConnectOp> writeOps;
3232 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
3233 auto portTy = type_cast<BundleType>(port.getType());
3234 auto fieldIndex = portTy.getElementIndex(field);
3235 assert(fieldIndex &&
"missing data port");
3237 for (
auto *op : port.getUsers()) {
3238 auto portAccess = cast<SubfieldOp>(op);
3239 if (fieldIndex != portAccess.getFieldIndex())
3246 writeOps.push_back(conn);
3252 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3254 if (!mem.getPortAnnotation(i).empty())
3257 switch (mem.getPortKind(i)) {
3258 case MemOp::PortKind::Debug:
3261 case MemOp::PortKind::Write:
3262 if (failed(findWriteUsers(port,
"data")))
3265 case MemOp::PortKind::Read:
3266 if (failed(findReadUsers(port,
"data")))
3269 case MemOp::PortKind::ReadWrite:
3270 if (failed(findWriteUsers(port,
"wdata")))
3272 if (failed(findReadUsers(port,
"rdata")))
3276 llvm_unreachable(
"unknown port kind");
3280 if (usedBits.none())
3284 SmallVector<std::pair<unsigned, unsigned>> ranges;
3285 unsigned newWidth = 0;
3286 for (
int i = usedBits.find_first(); 0 <= i && i < width;) {
3287 int e = usedBits.find_next_unset(i);
3290 for (
int idx = i; idx < e; ++idx, ++newWidth) {
3291 if (
auto it = mapping.find(idx); it != mapping.end()) {
3292 it->second = newWidth;
3295 ranges.emplace_back(i, e - 1);
3296 i = e != width ? usedBits.find_next(e) : e;
3300 auto newType =
IntType::get(mem->getContext(), type.isSigned(), newWidth);
3301 SmallVector<Type> portTypes;
3302 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3303 portTypes.push_back(
3304 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
3306 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
3307 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
3308 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
3309 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
3310 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3313 auto rewriteSubfield = [&](Value port, StringRef field) {
3314 auto portTy = type_cast<BundleType>(port.getType());
3315 auto fieldIndex = portTy.getElementIndex(field);
3316 assert(fieldIndex &&
"missing data port");
3318 rewriter.setInsertionPointAfter(newMem);
3319 auto newPortAccess =
3320 SubfieldOp::create(rewriter, port.getLoc(), port, field);
3322 for (
auto *op :
llvm::make_early_inc_range(port.getUsers())) {
3323 auto portAccess = cast<SubfieldOp>(op);
3324 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
3326 rewriter.replaceOp(portAccess, newPortAccess.getResult());
3331 for (
auto [i, port] :
llvm::enumerate(newMem.getResults())) {
3332 switch (newMem.getPortKind(i)) {
3333 case MemOp::PortKind::Debug:
3334 llvm_unreachable(
"cannot rewrite debug port");
3335 case MemOp::PortKind::Write:
3336 rewriteSubfield(port,
"data");
3338 case MemOp::PortKind::Read:
3339 rewriteSubfield(port,
"data");
3341 case MemOp::PortKind::ReadWrite:
3342 rewriteSubfield(port,
"rdata");
3343 rewriteSubfield(port,
"wdata");
3346 llvm_unreachable(
"unknown port kind");
3350 for (
auto readOp : readOps) {
3351 rewriter.setInsertionPointAfter(readOp);
3352 auto it = mapping.find(readOp.getLo());
3353 assert(it != mapping.end() &&
"bit op mapping not found");
3356 auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
3357 readOp.getLoc(), readOp.getInput(),
3358 readOp.getHi() - readOp.getLo() + it->second, it->second);
3359 rewriter.replaceAllUsesWith(readOp, newReadValue);
3360 rewriter.eraseOp(readOp);
3364 for (
auto writeOp : writeOps) {
3365 Value source = writeOp.getSrc();
3366 rewriter.setInsertionPoint(writeOp);
3368 SmallVector<Value> slices;
3369 for (
auto &[start, end] :
llvm::reverse(ranges)) {
3370 Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
3371 source,
end, start);
3372 slices.push_back(slice);
3376 rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);
3382 if (type.isSigned())
3384 rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);
3386 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
3396 using OpRewritePattern::OpRewritePattern;
3397 LogicalResult matchAndRewrite(MemOp mem,
3398 PatternRewriter &rewriter)
const override {
3403 auto ty = mem.getDataType();
3404 auto loc = mem.getLoc();
3405 auto *block = mem->getBlock();
3409 SmallPtrSet<Operation *, 8> connects;
3410 SmallVector<SubfieldOp> portAccesses;
3411 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3412 if (!mem.getPortAnnotation(i).empty())
3415 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
3416 auto portTy = type_cast<BundleType>(port.getType());
3417 for (
auto field : fields) {
3418 auto fieldIndex = portTy.getElementIndex(field);
3419 assert(fieldIndex &&
"missing field on memory port");
3421 for (
auto *op : port.getUsers()) {
3422 auto portAccess = cast<SubfieldOp>(op);
3423 if (fieldIndex != portAccess.getFieldIndex())
3425 portAccesses.push_back(portAccess);
3426 for (
auto *user : portAccess->getUsers()) {
3427 auto conn = dyn_cast<FConnectLike>(user);
3430 connects.insert(conn);
3437 switch (mem.getPortKind(i)) {
3438 case MemOp::PortKind::Debug:
3440 case MemOp::PortKind::Read:
3441 if (failed(collect({
"clk",
"en",
"addr"})))
3444 case MemOp::PortKind::Write:
3445 if (failed(collect({
"clk",
"en",
"addr",
"data",
"mask"})))
3448 case MemOp::PortKind::ReadWrite:
3449 if (failed(collect({
"clk",
"en",
"addr",
"wmode",
"wdata",
"wmask"})))
3455 if (!portClock || (clock && portClock != clock))
3461 rewriter.setInsertionPointAfter(mem);
3462 auto memWire = WireOp::create(rewriter, loc, ty).getResult();
3468 rewriter.setInsertionPointToEnd(block);
3470 RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();
3473 MatchingConnectOp::create(rewriter, loc, memWire, memReg);
3477 auto pipeline = [&](Value value, Value clock,
const Twine &name,
3479 for (
unsigned i = 0; i < latency; ++i) {
3480 std::string regName;
3482 llvm::raw_string_ostream os(regName);
3483 os << mem.getName() <<
"_" << name <<
"_" << i;
3485 auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
3486 rewriter.getStringAttr(regName))
3488 MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
3494 const unsigned writeStages =
info.writeLatency - 1;
3499 SmallVector<std::tuple<Value, Value, Value>> writes;
3500 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3502 StringRef name = mem.getPortName(i);
3504 auto portPipeline = [&, port = port](StringRef field,
unsigned stages) {
3507 return pipeline(value, portClock, name +
"_" + field, stages);
3510 switch (mem.getPortKind(i)) {
3511 case MemOp::PortKind::Debug:
3512 llvm_unreachable(
"unknown port kind");
3513 case MemOp::PortKind::Read: {
3521 case MemOp::PortKind::Write: {
3522 auto data = portPipeline(
"data", writeStages);
3523 auto en = portPipeline(
"en", writeStages);
3524 auto mask = portPipeline(
"mask", writeStages);
3528 case MemOp::PortKind::ReadWrite: {
3533 auto wdata = portPipeline(
"wdata", writeStages);
3534 auto wmask = portPipeline(
"wmask", writeStages);
3539 auto wen = AndPrimOp::create(rewriter, port.getLoc(),
en,
wmode);
3541 pipeline(wen, portClock, name +
"_wen", writeStages);
3542 writes.emplace_back(
wdata, wenPipelined,
wmask);
3549 Value next = memReg;
3555 Location loc = mem.getLoc();
3556 unsigned maskGran =
info.dataWidth /
info.maskBits;
3557 SmallVector<Value> chunks;
3558 for (
unsigned i = 0; i <
info.maskBits; ++i) {
3559 unsigned hi = (i + 1) * maskGran - 1;
3560 unsigned lo = i * maskGran;
3562 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc,
data, hi, lo);
3563 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3564 auto bit = rewriter.createOrFold<BitsPrimOp>(loc,
mask, i, i);
3565 auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
3566 chunks.push_back(chunk);
3569 std::reverse(chunks.begin(), chunks.end());
3570 masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
3571 next = MuxPrimOp::create(rewriter, next.getLoc(),
en, masked, next);
3573 Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
3574 MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);
3577 for (Operation *conn : connects)
3578 rewriter.eraseOp(
conn);
3579 for (
auto portAccess : portAccesses)
3580 rewriter.eraseOp(portAccess);
3581 rewriter.eraseOp(mem);
3588void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3591 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3592 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3612 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3615 auto *high = mux.getHigh().getDefiningOp();
3616 auto *low = mux.getLow().getDefiningOp();
3618 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3625 bool constReg =
false;
3627 if (constOp && low == reg)
3629 else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
3631 constOp = dyn_cast<ConstantOp>(low);
3638 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3642 auto regTy = reg.getResult().getType();
3643 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3644 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3645 regTy.getBitWidthOrSentinel() < 0)
3651 if (constOp != &con->getBlock()->front())
3652 constOp->moveBefore(&con->getBlock()->front());
3655 SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
3656 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3657 rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
3658 mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
3659 reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
3660 reg.getForceableAttr());
3661 newReg->setDialectAttrs(attrs);
3663 auto pt = rewriter.saveInsertionPoint();
3664 rewriter.setInsertionPoint(con);
3665 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3666 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3667 rewriter.restoreInsertionPoint(pt);
3671LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3672 if (!
hasDontTouch(op.getOperation()) && !op.isForceable() &&
3688 PatternRewriter &rewriter,
3691 if (
auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3692 if (constant.getValue().isZero()) {
3693 rewriter.eraseOp(op);
3699 if (
auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3700 if (constant.getValue().isZero() == eraseIfZero) {
3701 rewriter.eraseOp(op);
3709template <
class Op,
bool EraseIfZero = false>
3711 PatternRewriter &rewriter) {
3716void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3718 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3719 results.add<patterns::AssertXWhenX>(
context);
3722void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3724 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3725 results.add<patterns::AssumeXWhenX>(
context);
3728void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3729 RewritePatternSet &results, MLIRContext *
context) {
3730 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3731 results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(
context);
3734void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3736 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3743LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3744 PatternRewriter &rewriter) {
3746 if (op.use_empty()) {
3747 rewriter.eraseOp(op);
3754 if (op->hasOneUse() &&
3755 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3756 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3757 *op->user_begin()) ||
3758 (isa<CvtPrimOp>(*op->user_begin()) &&
3759 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3760 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3761 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3762 .getBitWidthOrSentinel() > 0))) {
3763 auto *modop = *op->user_begin();
3764 auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
3765 modop->getResult(0).getType());
3766 rewriter.replaceAllOpUsesWith(modop, inv);
3767 rewriter.eraseOp(modop);
3768 rewriter.eraseOp(op);
3774OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3775 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3776 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3784OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3793 return BoolAttr::get(getContext(),
false);
3797 return BoolAttr::get(getContext(),
false);
3802LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3803 PatternRewriter &rewriter) {
3805 if (
auto testEnable = op.getTestEnable()) {
3806 if (
auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3807 if (constOp.getValue().isZero()) {
3808 rewriter.modifyOpInPlace(op,
3809 [&] { op.getTestEnableMutable().clear(); });
3825 auto forceable = op.getRef().getDefiningOp<Forceable>();
3826 if (!forceable || !forceable.isForceable() ||
3827 op.getRef() != forceable.getDataRef() ||
3828 op.getType() != forceable.getDataType())
3830 rewriter.replaceAllUsesWith(op, forceable.getData());
3834void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3836 results.insert<patterns::RefResolveOfRefSend>(
context);
3840OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3842 if (getInput().getType() == getType())
3848 auto constOp = operand.getDefiningOp<ConstantOp>();
3849 return constOp && constOp.getValue().isZero();
3852template <
typename Op>
3855 rewriter.eraseOp(op);
3861void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3863 results.add(eraseIfPredFalse<RefForceOp>);
3865void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3867 results.add(eraseIfPredFalse<RefForceInitialOp>);
3869void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3871 results.add(eraseIfPredFalse<RefReleaseOp>);
3873void RefReleaseInitialOp::getCanonicalizationPatterns(
3874 RewritePatternSet &results, MLIRContext *
context) {
3875 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3882OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3888 if (adaptor.getReset())
3893 if (
isUInt1(getReset().getType()) && adaptor.getClock())
3906 [&](
auto ty) ->
bool {
return isTypeEmpty(ty.getElementType()); })
3907 .Case<BundleType>([&](
auto ty) ->
bool {
3908 for (
auto elem : ty.getElements())
3913 .Case<IntType>([&](
auto ty) {
return ty.getWidth() == 0; })
3914 .Default([](
auto) ->
bool {
return false; });
3917LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3918 PatternRewriter &rewriter) {
3919 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3926 rewriter.eraseOp(op);
3934LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3935 PatternRewriter &rewriter) {
3938 if (op.getBody()->empty()) {
3939 rewriter.eraseOp(op);
3950OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3952 if (getDomains().
empty())
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user that strictly dominates "dominatedAttach", then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static std::optional< bool > getBoolValue(Attribute attr)
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
LogicalResult matchAndRewrite(BitsPrimOp bits, mlir::PatternRewriter &rewriter) const override