20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/TypeSwitch.h"
28using namespace firrtl;
32static Value
dropWrite(PatternRewriter &rewriter, OpResult old,
34 SmallPtrSet<Operation *, 8> users;
35 for (
auto *user : old.getUsers())
37 for (Operation *user : users)
38 if (
auto connect = dyn_cast<FConnectLike>(user))
39 if (connect.getDest() == old)
40 rewriter.eraseOp(user);
50 if (op->getNumRegions() != 0)
52 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
60 Operation *op = passthrough.getDefiningOp();
63 assert(op &&
"passthrough must be an operation");
64 Operation *oldOp = old.getOwner();
65 auto name = oldOp->getAttrOfType<StringAttr>(
"name");
67 op->setAttr(
"name", name);
75#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
83 auto resultType = type_cast<IntType>(op->getResult(0).getType());
84 if (!resultType.hasWidth())
86 for (Value operand : op->getOperands())
87 if (!type_cast<IntType>(operand.getType()).hasWidth())
94 auto t = type_dyn_cast<UIntType>(type);
95 if (!t || !t.hasWidth() || t.getWidth() != 1)
102static void updateName(PatternRewriter &rewriter, Operation *op,
107 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) &&
"Should never rename");
108 auto newName = name.getValue();
109 auto newOpName = op->getAttrOfType<StringAttr>(
"name");
112 newName =
chooseName(newOpName.getValue(), name.getValue());
114 if (!newOpName || newOpName.getValue() != newName)
115 rewriter.modifyOpInPlace(
116 op, [&] { op->setAttr(
"name", rewriter.getStringAttr(newName)); });
124 if (
auto *newOp = newValue.getDefiningOp()) {
125 auto name = op->getAttrOfType<StringAttr>(
"name");
128 rewriter.replaceOp(op, newValue);
134template <
typename OpTy,
typename... Args>
136 Operation *op, Args &&...args) {
137 auto name = op->getAttrOfType<StringAttr>(
"name");
139 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
147 if (
auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
148 return namableOp.hasDroppableName();
159static std::optional<APSInt>
161 assert(type_cast<IntType>(operand.getType()) &&
162 "getExtendedConstant is limited to integer types");
169 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
174 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
175 return APSInt(destWidth,
176 type_cast<IntType>(operand.getType()).isUnsigned());
184 if (
auto attr = dyn_cast<BoolAttr>(operand))
185 return APSInt(APInt(1, attr.getValue()));
186 if (
auto attr = dyn_cast<IntegerAttr>(operand))
187 return attr.getAPSInt();
195 return cst->isZero();
222 Operation *op, ArrayRef<Attribute> operands,
BinOpKind opKind,
223 const function_ref<APInt(
const APSInt &,
const APSInt &)> &calculate) {
224 assert(operands.size() == 2 &&
"binary op takes two operands");
227 auto resultType = type_cast<IntType>(op->getResult(0).getType());
228 if (resultType.getWidthOrSentinel() < 0)
232 if (resultType.getWidthOrSentinel() == 0)
233 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
239 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
241 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
242 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
243 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
244 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
245 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
249 int32_t operandWidth;
252 operandWidth = resultType.getWidthOrSentinel();
257 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
261 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
272 APInt resultValue = calculate(*lhs, *rhs);
277 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
279 assert((
unsigned)resultType.getWidthOrSentinel() ==
280 resultValue.getBitWidth());
293 Operation *op, PatternRewriter &rewriter,
294 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
296 if (op->getNumResults() != 1)
298 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
303 auto width = type.getBitWidthOrSentinel();
308 SmallVector<Attribute, 3> constOperands;
309 constOperands.reserve(op->getNumOperands());
310 for (
auto operand : op->getOperands()) {
312 if (
auto *defOp = operand.getDefiningOp())
313 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
314 [&](
auto op) { attr = op.getValueAttr(); });
315 constOperands.push_back(attr);
320 auto result = canonicalize(constOperands);
324 if (
auto cst = dyn_cast<Attribute>(result))
325 resultValue = op->getDialect()
326 ->materializeConstant(rewriter, cst, type, op->getLoc())
329 resultValue = cast<Value>(result);
333 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
334 resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);
337 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
338 resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
339 else if (type_isa<UIntType>(type) &&
340 type_isa<SIntType>(resultValue.getType()))
341 resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);
343 assert(type == resultValue.getType() &&
"canonicalization changed type");
351 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
357 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
363 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
/// Fold hook for ConstantOp. The adaptor is asserted to carry no operands,
/// and the fold result is simply whatever getValueAttr() returns.
370OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
371 assert(adaptor.getOperands().empty() &&
"constant has no operands");
372 return getValueAttr();
/// Fold hook for SpecialConstantOp. The adaptor is asserted to carry no
/// operands, and the fold result is whatever getValueAttr() returns.
375OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
376 assert(adaptor.getOperands().empty() &&
"constant has no operands");
377 return getValueAttr();
/// Fold hook for AggregateConstantOp. The adaptor is asserted to carry no
/// operands. Unlike the other constant folds in this group, the result comes
/// from getFieldsAttr() rather than getValueAttr().
380OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
381 assert(adaptor.getOperands().empty() &&
"constant has no operands");
382 return getFieldsAttr();
/// Fold hook for StringConstantOp. The adaptor is asserted to carry no
/// operands, and the fold result is whatever getValueAttr() returns.
385OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
386 assert(adaptor.getOperands().empty() &&
"constant has no operands");
387 return getValueAttr();
/// Fold hook for FIntegerConstantOp. The adaptor is asserted to carry no
/// operands, and the fold result is whatever getValueAttr() returns.
390OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
391 assert(adaptor.getOperands().empty() &&
"constant has no operands");
392 return getValueAttr();
/// Fold hook for BoolConstantOp. The adaptor is asserted to carry no
/// operands, and the fold result is whatever getValueAttr() returns.
395OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
396 assert(adaptor.getOperands().empty() &&
"constant has no operands");
397 return getValueAttr();
/// Fold hook for DoubleConstantOp. The adaptor is asserted to carry no
/// operands, and the fold result is whatever getValueAttr() returns.
400OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
401 assert(adaptor.getOperands().empty() &&
"constant has no operands");
402 return getValueAttr();
409OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
412 [=](
const APSInt &a,
const APSInt &b) { return a + b; });
415void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
418 patterns::AddOfSelf, patterns::AddOfPad>(
context);
421OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
424 [=](
const APSInt &a,
const APSInt &b) { return a - b; });
427void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
429 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
430 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
431 patterns::SubOfPadL, patterns::SubOfPadR>(
context);
434OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
446 [=](
const APSInt &a,
const APSInt &b) { return a * b; });
449OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
456 if (getLhs() == getRhs()) {
457 auto width = getType().base().getWidthOrSentinel();
462 return getIntAttr(getType(), APInt(width, 1));
479 if (
auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
480 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
485 [=](
const APSInt &a,
const APSInt &b) -> APInt {
488 return APInt(a.getBitWidth(), 0);
492OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
499 if (getLhs() == getRhs())
513 [=](
const APSInt &a,
const APSInt &b) -> APInt {
516 return APInt(a.getBitWidth(), 0);
520OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
523 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
526OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
529 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
532OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
535 [=](
const APSInt &a,
const APSInt &b) -> APInt {
536 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
542OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
545 if (rhsCst->isZero())
549 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
556 if (lhsCst->isZero())
560 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
561 getRhs().getType() == getType())
566 if (getLhs() == getRhs() && getRhs().getType() == getType())
571 [](
const APSInt &a,
const APSInt &b) -> APInt { return a & b; });
574void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
577 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
578 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
579 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(
context);
582OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
585 if (rhsCst->isZero() && getLhs().getType() == getType())
589 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
590 getLhs().getType() == getType())
596 if (lhsCst->isZero() && getRhs().getType() == getType())
600 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
601 getRhs().getType() == getType())
606 if (getLhs() == getRhs() && getRhs().getType() == getType())
611 [](
const APSInt &a,
const APSInt &b) -> APInt { return a | b; });
614void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
616 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
617 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
621OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
624 if (rhsCst->isZero() &&
630 if (lhsCst->isZero() &&
635 if (getLhs() == getRhs())
638 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
642 [](
const APSInt &a,
const APSInt &b) -> APInt { return a ^ b; });
645void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
647 results.insert<patterns::extendXor, patterns::moveConstXor,
648 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
652void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
654 results.insert<patterns::LEQWithConstLHS>(
context);
657OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
658 bool isUnsigned = getLhs().getType().base().isUnsigned();
661 if (getLhs() == getRhs())
665 if (
auto width = getLhs().getType().base().
getWidth()) {
667 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
668 commonWidth = std::max(commonWidth, 1);
679 if (isUnsigned && rhsCst->zext(commonWidth)
692 [=](
const APSInt &a,
const APSInt &b) -> APInt {
693 return APInt(1, a <= b);
697void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
699 results.insert<patterns::LTWithConstLHS>(
context);
702OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
703 IntType lhsType = getLhs().getType();
707 if (getLhs() == getRhs())
717 if (
auto width = lhsType.
getWidth()) {
719 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
720 commonWidth = std::max(commonWidth, 1);
731 if (isUnsigned && rhsCst->zext(commonWidth)
744 [=](
const APSInt &a,
const APSInt &b) -> APInt {
745 return APInt(1, a < b);
749void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
751 results.insert<patterns::GEQWithConstLHS>(
context);
754OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
755 IntType lhsType = getLhs().getType();
759 if (getLhs() == getRhs())
764 if (rhsCst->isZero() && isUnsigned)
769 if (
auto width = lhsType.
getWidth()) {
771 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
772 commonWidth = std::max(commonWidth, 1);
775 if (isUnsigned && rhsCst->zext(commonWidth)
796 [=](
const APSInt &a,
const APSInt &b) -> APInt {
797 return APInt(1, a >= b);
801void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
803 results.insert<patterns::GTWithConstLHS>(
context);
806OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
807 IntType lhsType = getLhs().getType();
811 if (getLhs() == getRhs())
815 if (
auto width = lhsType.
getWidth()) {
817 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
818 commonWidth = std::max(commonWidth, 1);
821 if (isUnsigned && rhsCst->zext(commonWidth)
842 [=](
const APSInt &a,
const APSInt &b) -> APInt {
843 return APInt(1, a > b);
847OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
849 if (getLhs() == getRhs())
855 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
856 getRhs().getType() == getType())
862 [=](
const APSInt &a,
const APSInt &b) -> APInt {
863 return APInt(1, a == b);
867LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
869 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
871 auto width = op.getLhs().getType().getBitWidthOrSentinel();
874 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
875 op.getRhs().getType() == op.getType()) {
876 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
881 if (rhsCst->isZero() && width > 1) {
882 auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
883 return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
887 if (rhsCst->isAllOnes() && width > 1 &&
888 op.getLhs().getType() == op.getRhs().getType()) {
889 return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
897OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
899 if (getLhs() == getRhs())
905 if (rhsCst->isZero() && getLhs().getType() == getType() &&
906 getRhs().getType() == getType())
912 [=](
const APSInt &a,
const APSInt &b) -> APInt {
913 return APInt(1, a != b);
917LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
919 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
921 auto width = op.getLhs().getType().getBitWidthOrSentinel();
924 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
925 op.getRhs().getType() == op.getType()) {
926 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
931 if (rhsCst->isZero() && width > 1) {
932 return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
937 if (rhsCst->isAllOnes() && width > 1 &&
938 op.getLhs().getType() == op.getRhs().getType()) {
940 AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
941 return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
949OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
955OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
961OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
965 return IntegerAttr::get(
966 IntegerType::get(getContext(), lhsCst->getBitWidth()),
967 lhsCst->ashr(*rhsCst));
970 if (rhsCst->isZero())
977OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
982 return IntegerAttr::get(
983 IntegerType::get(getContext(), lhsCst->getBitWidth()),
984 lhsCst->shl(*rhsCst));
987 if (rhsCst->isZero())
998OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
999 auto base = getInput().getType();
1006OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
1013OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1021 if (getType().base().hasWidth())
1028void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1030 results.insert<patterns::StoUtoS>(
context);
1033OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1041 if (getType().base().hasWidth())
1048void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1050 results.insert<patterns::UtoStoU>(
context);
1053OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1055 if (getInput().getType() == getType())
1060 return BoolAttr::get(getContext(), cst->getBoolValue());
1065OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1067 if (getInput().getType() == getType())
1072 return BoolAttr::get(getContext(), cst->getBoolValue());
1077OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1083 getType().base().getWidthOrSentinel()))
1089void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1091 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(
context);
1094OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1101 getType().base().getWidthOrSentinel()))
1102 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1107OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1112 getType().base().getWidthOrSentinel()))
1118void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1120 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1121 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1128 : RewritePattern(opName, 0,
context) {}
1134 ConstantOp constantOp,
1135 SmallVectorImpl<Value> &remaining)
const = 0;
1142 mlir::PatternRewriter &rewriter)
const override {
1144 auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
1148 SmallVector<Value> nonConstantOperands;
1151 for (
auto operand : catOp.getInputs()) {
1152 if (
auto constantOp = operand.getDefiningOp<ConstantOp>()) {
1154 if (
handleConstant(rewriter, op, constantOp, nonConstantOperands))
1158 nonConstantOperands.push_back(operand);
1163 if (nonConstantOperands.empty()) {
1164 replaceOpWithNewOpAndCopyName<ConstantOp>(
1165 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1171 if (nonConstantOperands.size() == 1) {
1172 rewriter.modifyOpInPlace(
1173 op, [&] { op->setOperand(0, nonConstantOperands.front()); });
1178 if (catOp->hasOneUse() &&
1179 nonConstantOperands.size() < catOp->getNumOperands()) {
1180 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
1181 nonConstantOperands);
1194 SmallVectorImpl<Value> &remaining)
const override {
1195 if (value.getValue().isZero())
1198 replaceOpWithNewOpAndCopyName<ConstantOp>(
1199 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1212 SmallVectorImpl<Value> &remaining)
const override {
1213 if (value.getValue().isAllOnes())
1216 replaceOpWithNewOpAndCopyName<ConstantOp>(
1217 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1230 SmallVectorImpl<Value> &remaining)
const override {
1231 if (value.getValue().isZero())
1233 remaining.push_back(value);
1239OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1243 if (getInput().getType().getBitWidthOrSentinel() == 0)
1248 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1252 if (
isUInt1(getInput().getType()))
1258void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1260 results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1261 patterns::AndRPadS, patterns::AndRCatAndR_left,
1265OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1269 if (getInput().getType().getBitWidthOrSentinel() == 0)
1274 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1278 if (
isUInt1(getInput().getType()))
1284void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1286 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1287 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right,
OrRCat>(
1291OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1295 if (getInput().getType().getBitWidthOrSentinel() == 0)
1300 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1303 if (
isUInt1(getInput().getType()))
1309void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1312 .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1313 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right,
XorRCat>(
1321OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1322 auto inputs = getInputs();
1323 auto inputAdaptors = adaptor.getInputs();
1330 if (inputs.size() == 1 && inputs[0].getType() == getType())
1338 SmallVector<Value> nonZeroInputs;
1339 SmallVector<Attribute> nonZeroAttributes;
1340 bool allConstant =
true;
1341 for (
auto [input, attr] :
llvm::zip(inputs, inputAdaptors)) {
1342 auto inputType = type_cast<IntType>(input.getType());
1343 if (inputType.getBitWidthOrSentinel() != 0) {
1344 nonZeroInputs.push_back(input);
1346 allConstant =
false;
1347 if (nonZeroInputs.size() > 1 && !allConstant)
1353 if (nonZeroInputs.empty())
1357 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1358 return nonZeroInputs[0];
1364 SmallVector<APInt> constants;
1365 for (
auto inputAdaptor : inputAdaptors) {
1367 constants.push_back(*cst);
1372 assert(!constants.empty());
1374 APInt result = constants[0];
1375 for (
size_t i = 1; i < constants.size(); ++i)
1376 result = result.concat(constants[i]);
1381void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1383 results.insert<patterns::DShlOfConstant>(
context);
1386void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1388 results.insert<patterns::DShrOfConstant>(
context);
1394class FlattenCat :
public mlir::RewritePattern {
1396 FlattenCat(MLIRContext *
context)
1397 : RewritePattern(CatPrimOp::getOperationName(), 0,
context) {}
1400 matchAndRewrite(Operation *op,
1401 mlir::PatternRewriter &rewriter)
const override {
1402 auto cat = cast<CatPrimOp>(op);
1404 cat.getType().getBitWidthOrSentinel() == 0)
1408 if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
1412 SmallVector<Value> operands;
1413 SmallVector<Value> worklist;
1414 auto pushOperands = [&worklist](CatPrimOp op) {
1415 for (
auto operand :
llvm::reverse(op.getInputs()))
1416 worklist.push_back(operand);
1419 bool hasSigned =
false, hasUnsigned =
false;
1420 while (!worklist.empty()) {
1421 auto value = worklist.pop_back_val();
1422 auto catOp = value.getDefiningOp<CatPrimOp>();
1424 operands.push_back(value);
1425 (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) =
true;
1429 pushOperands(catOp);
1434 auto castToUIntIfSigned = [&](Value value) -> Value {
1435 if (type_isa<UIntType>(value.getType()))
1437 return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
1440 assert(operands.size() >= 1 &&
"zero width cast must be rejected");
1442 if (operands.size() == 1) {
1443 rewriter.replaceOp(op, castToUIntIfSigned(operands[0]));
1447 if (operands.size() == cat->getNumOperands())
1451 if (hasSigned && hasUnsigned)
1452 for (
auto &operand : operands)
1453 operand = castToUIntIfSigned(operand);
1455 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
1462class CatOfConstant :
public mlir::RewritePattern {
1464 CatOfConstant(MLIRContext *
context)
1465 : RewritePattern(CatPrimOp::getOperationName(), 0,
context) {}
1468 matchAndRewrite(Operation *op,
1469 mlir::PatternRewriter &rewriter)
const override {
1470 auto cat = cast<CatPrimOp>(op);
1474 SmallVector<Value> operands;
1476 for (
size_t i = 0; i < cat->getNumOperands(); ++i) {
1477 auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
1479 operands.push_back(cat.getInputs()[i]);
1482 APSInt value = cst.getValue();
1484 for (; j < cat->getNumOperands(); ++j) {
1485 auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
1488 value = value.concat(nextCst.getValue());
1493 operands.push_back(cst);
1496 operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
1502 if (operands.size() == cat->getNumOperands())
1505 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
1514void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1516 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1517 patterns::CatCast, FlattenCat, CatOfConstant>(
context);
1520OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1523 if (op.getType() == op.getInput().getType())
1524 return op.getInput();
1528 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1529 if (op.getType() == in.getInput().getType())
1530 return in.getInput();
1535OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1536 IntType inputType = getInput().getType();
1537 IntType resultType = getType();
1539 if (inputType == getType() && resultType.
hasWidth())
1546 cst->extractBits(getHi() - getLo() + 1, getLo()));
1553 : RewritePattern(BitsPrimOp::getOperationName(), 0,
context) {}
1557 mlir::PatternRewriter &rewriter)
const override {
1558 auto bits = cast<BitsPrimOp>(op);
1559 auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
1562 int32_t bitPos = bits.getLo();
1563 auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
1564 if (resultWidth < 0)
1566 for (
auto operand : llvm::reverse(cat.getInputs())) {
1568 type_cast<IntType>(operand.getType()).getWidthOrSentinel();
1569 if (operandWidth < 0)
1571 if (bitPos < operandWidth) {
1572 if (bitPos + resultWidth <= operandWidth) {
1573 auto newBits = rewriter.createOrFold<BitsPrimOp>(
1574 op->getLoc(), operand, bitPos + resultWidth - 1, bitPos);
1580 bitPos -= operandWidth;
1586void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1589 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1597 unsigned loBit, PatternRewriter &rewriter) {
1598 auto resType = type_cast<IntType>(op->getResult(0).getType());
1599 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1600 value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);
1602 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1603 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1604 }
else if (resType.isUnsigned() &&
1605 !type_cast<IntType>(value.getType()).isUnsigned()) {
1606 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1608 rewriter.replaceOp(op, value);
1611template <
typename OpTy>
1612static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
1614 if (op.getType().getBitWidthOrSentinel() == 0)
1616 APInt(0, 0, op.getType().isSignedInteger()));
1619 if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
1620 return op.getHigh();
1625 if (op.getType().getBitWidthOrSentinel() < 0)
1630 if (cond->isZero() && op.getLow().getType() == op.getType())
1632 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1633 return op.getHigh();
1637 if (
auto lowCst =
getConstant(adaptor.getLow())) {
1639 if (
auto highCst =
getConstant(adaptor.getHigh())) {
1641 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1642 *highCst == *lowCst)
1645 if (highCst->isOne() && lowCst->isZero() &&
1646 op.getType() == op.getSel().getType())
/// Fold hook for MuxPrimOp. Forwards this op and its fold adaptor to the
/// foldMux helper template and returns that helper's result unchanged.
1659OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1660 return foldMux(*
this, adaptor);
/// Fold hook for Mux2CellIntrinsicOp. Uses the same foldMux helper template
/// as MuxPrimOp::fold, instantiated for this op type, and returns its result.
1663OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1664 return foldMux(*
this, adaptor);
/// Fold hook for Mux4CellIntrinsicOp. Always returns a default-constructed
/// (empty) OpFoldResult and never reads `adaptor`.
/// NOTE(review): presumably the empty result means "nothing folded" to the
/// caller -- confirm against the MLIR fold contract.
1667OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return {}; }
1674class MuxPad :
public mlir::RewritePattern {
1677 : RewritePattern(MuxPrimOp::getOperationName(), 0,
context) {}
1680 matchAndRewrite(Operation *op,
1681 mlir::PatternRewriter &rewriter)
const override {
1682 auto mux = cast<MuxPrimOp>(op);
1683 auto width = mux.getType().getBitWidthOrSentinel();
1687 auto pad = [&](Value input) -> Value {
1689 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1690 if (inputWidth < 0 || width == inputWidth)
1692 return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
1697 auto newHigh = pad(mux.getHigh());
1698 auto newLow = pad(mux.getLow());
1699 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1702 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1703 rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1711class MuxSharedCond :
public mlir::RewritePattern {
1713 MuxSharedCond(MLIRContext *
context)
1714 : RewritePattern(MuxPrimOp::getOperationName(), 0,
context) {}
1716 static const int depthLimit = 5;
1718 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1719 mlir::PatternRewriter &rewriter,
1720 bool updateInPlace)
const {
1721 if (updateInPlace) {
1722 rewriter.modifyOpInPlace(mux, [&] {
1723 mux.setOperand(1, high);
1724 mux.setOperand(2, low);
1728 rewriter.setInsertionPointAfter(mux);
1729 return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
1730 ValueRange{mux.getSel(), high, low})
1735 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1736 bool updateInPlace,
int limit)
const {
1737 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1740 if (mux.getSel() == cond)
1741 return mux.getHigh();
1742 if (limit > depthLimit)
1744 updateInPlace &= mux->hasOneUse();
1746 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1748 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1751 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1752 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1757 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1758 bool updateInPlace,
int limit)
const {
1759 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1762 if (mux.getSel() == cond)
1763 return mux.getLow();
1764 if (limit > depthLimit)
1766 updateInPlace &= mux->hasOneUse();
1768 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1770 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1772 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1774 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1780 matchAndRewrite(Operation *op,
1781 mlir::PatternRewriter &rewriter)
const override {
1782 auto mux = cast<MuxPrimOp>(op);
1783 auto width = mux.getType().getBitWidthOrSentinel();
1787 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter,
true, 0)) {
1788 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1792 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter,
true, 0)) {
1793 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
1802void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1805 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1806 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1807 patterns::MuxSameTrue, patterns::MuxSameFalse,
1808 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
/// Populates `results` with the canonicalization patterns for
/// Mux2CellIntrinsicOp. The only pattern added here is patterns::Mux2PadSel,
/// with `context` passed through to results.add.
1812void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1813 RewritePatternSet &results, MLIRContext *
context) {
1814 results.add<patterns::Mux2PadSel>(
context);
/// Populates `results` with the canonicalization patterns for
/// Mux4CellIntrinsicOp. The only pattern added here is patterns::Mux4PadSel,
/// with `context` passed through to results.add.
1817void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1818 RewritePatternSet &results, MLIRContext *
context) {
1819 results.add<patterns::Mux4PadSel>(
context);
1822OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1823 auto input = this->getInput();
1826 if (input.getType() == getType())
1830 auto inputType = input.getType().base();
1837 auto destWidth = getType().base().getWidthOrSentinel();
1838 if (destWidth == -1)
1841 if (inputType.
isSigned() && cst->getBitWidth())
1842 return getIntAttr(getType(), cst->sext(destWidth));
1843 return getIntAttr(getType(), cst->zext(destWidth));
1849OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1850 auto input = this->getInput();
1851 IntType inputType = input.getType();
1852 int shiftAmount = getAmount();
1855 if (shiftAmount == 0)
1861 if (inputWidth != -1) {
1862 auto resultWidth = inputWidth + shiftAmount;
1863 shiftAmount = std::min(shiftAmount, resultWidth);
1864 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
1870OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
1871 auto input = this->getInput();
1872 IntType inputType = input.getType();
1873 int shiftAmount = getAmount();
1879 if (shiftAmount == 0 && inputWidth > 0)
1882 if (inputWidth == -1)
1884 if (inputWidth == 0)
1889 if (shiftAmount >= inputWidth && inputType.
isUnsigned())
1890 return getIntAttr(getType(), APInt(0, 0,
false));
1896 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
1898 value = cst->lshr(std::min(shiftAmount, inputWidth));
1899 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
1900 return getIntAttr(getType(), value.trunc(resultWidth));
1905LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
1906 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1907 if (inputWidth <= 0)
1911 unsigned shiftAmount = op.getAmount();
1912 if (
int(shiftAmount) >= inputWidth) {
1914 if (op.getType().base().isUnsigned())
1920 shiftAmount = inputWidth - 1;
1923 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
1927LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1928 PatternRewriter &rewriter) {
1929 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1930 if (inputWidth <= 0)
1934 unsigned keepAmount = op.getAmount();
1936 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
1941OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1945 getInput().getType().base().getWidthOrSentinel() - getAmount();
1946 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
1952OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1956 cst->trunc(getType().base().getWidthOrSentinel()));
1960LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1961 PatternRewriter &rewriter) {
1962 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1963 if (inputWidth <= 0)
1967 unsigned dropAmount = op.getAmount();
1968 if (dropAmount !=
unsigned(inputWidth))
1974void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1976 results.add<patterns::SubaccessOfConstant>(
context);
// Folder for MultibitMuxOp.
//  - With a single input, the op folds to that input.
//  - With a constant, in-range index, the op folds to the selected input.
// NOTE(review): the fall-through return is not visible in this excerpt.
1979OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
1981 if (adaptor.getInputs().size() == 1)
// Operand 1 is presumably the sole input, with operand 0 being the index
// -- TODO confirm against the op definition.
1982 return getOperand(1);
1984 if (
auto constIndex =
getConstant(adaptor.getIndex())) {
1985 auto index = constIndex->getZExtValue();
1986 if (index < getInputs().size())
// Inputs are addressed from the back: index 0 selects the last input.
1987 return getInputs()[getInputs().size() - 1 - index];
// Canonicalizes a MultibitMuxOp. The visible lines cover four cases:
//  1. every input is the same value;
//  2. the index is too narrow to address all inputs, so leading inputs are
//     erased in place;
//  3. the inputs are subindex ops over one common aggregate, in an order that
//     allows replacing the mux with a SubaccessOp;
//  4. exactly two inputs with a 1-bit index, replaced by a MuxPrimOp.
// NOTE(review): the bodies/returns of several branches are not visible in
// this excerpt.
1993LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
1994 PatternRewriter &rewriter) {
// Case 1: compare every input after the first with the first one.
1998 if (llvm::all_of(op.getInputs().drop_front(), [&](
auto input) {
1999 return input == op.getInputs().front();
// Case 2: with an index of `indexWidth` bits only 2^indexWidth inputs can be
// selected. The `indexWidth < 64` check keeps the shift below well defined.
2007 auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
2008 uint64_t inputSize = op.getInputs().size();
2009 if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
2010 rewriter.modifyOpInPlace(op, [&]() {
// Drop the surplus inputs from the front of the input list.
2011 op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
// Case 3: anchor on the subindex op that defines the last input.
2018 if (
auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
2019 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](
auto e) {
2020 auto subindex = e.value().template getDefiningOp<SubindexOp>();
// Each input must index the same aggregate, and its element index plus its
// position must equal size - 1 (i.e. element indices descend along the list).
2021 return subindex && lastSubindex.getInput() == subindex.getInput() &&
2022 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
2024 replaceOpWithNewOpAndCopyName<SubaccessOp>(
2025 rewriter, op, lastSubindex.getInput(), op.getIndex());
// Case 4: only two-input muxes with a 1-bit selector qualify.
2031 if (op.getInputs().size() != 2)
2035 auto uintType = op.getIndex().getType();
2036 if (uintType.getBitWidthOrSentinel() != 1)
2040 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
2041 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
2060 MatchingConnectOp connect;
2061 for (Operation *user : value.getUsers()) {
2063 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
2066 if (
auto aConnect = dyn_cast<FConnectLike>(user))
2067 if (aConnect.getDest() == value) {
2068 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
2071 if (!matchingConnect || (connect && connect != matchingConnect) ||
2072 matchingConnect->getBlock() != value.getParentBlock())
2074 connect = matchingConnect;
2082 PatternRewriter &rewriter) {
2085 Operation *connectedDecl = op.getDest().getDefiningOp();
2090 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
2094 cast<Forceable>(connectedDecl).isForceable())
2102 if (connectedDecl->hasOneUse())
2106 auto *declBlock = connectedDecl->getBlock();
2107 auto *srcValueOp = op.getSrc().getDefiningOp();
2110 if (!isa<WireOp>(connectedDecl))
2116 if (!isa<ConstantOp>(srcValueOp))
2118 if (srcValueOp->getBlock() != declBlock)
2124 auto replacement = op.getSrc();
2127 if (srcValueOp && srcValueOp != &declBlock->front())
2128 srcValueOp->moveBefore(&declBlock->front());
2135 rewriter.eraseOp(op);
// Registers the canonicalization patterns for ConnectOp:
// patterns::ConnectExtension and patterns::ConnectSameType.
2139void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2141 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
// Canonicalization entry point for MatchingConnectOp.
// NOTE(review): the body of this function is not visible in this excerpt.
2145LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2146 PatternRewriter &rewriter) {
2163 for (
auto *user : value.getUsers()) {
2164 auto attach = dyn_cast<AttachOp>(user);
2165 if (!attach || attach == dominatedAttach)
2167 if (attach->isBeforeInBlock(dominatedAttach))
// Canonicalizes an AttachOp. The visible lines cover three rewrites:
//  - an attach with at most one operand is erased;
//  - this attach is merged with another attach (`attach`) sharing `operand`,
//    by creating one attach with the union of operands and erasing both;
//  - an operand that is a wire with a single use, no dont-touch marking and
//    no forceable flag is dropped, and the wire is erased.
// NOTE(review): the definition of `attach` and the return statements are not
// visible in this excerpt.
2173LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
2175 if (op.getNumOperands() <= 1) {
2176 rewriter.eraseOp(op);
2180 for (
auto operand : op.getOperands()) {
// Start from this op's operands and append the other attach's operands,
// skipping the shared one so it is not listed twice.
2187 SmallVector<Value> newOperands(op.getOperands());
2188 for (
auto newOperand : attach.getOperands())
2189 if (newOperand != operand)
2190 newOperands.push_back(newOperand);
2191 AttachOp::create(rewriter, op->getLoc(), newOperands);
2192 rewriter.eraseOp(attach);
2193 rewriter.eraseOp(op);
// dyn_cast_or_null: getDefiningOp() may return null for the operand.
2201 if (
auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
2202 if (!
hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
2203 !wire.isForceable()) {
// Rebuild the operand list without the wire.
2204 SmallVector<Value> newOperands;
2205 for (
auto newOperand : op.getOperands())
2206 if (newOperand != operand)
2207 newOperands.push_back(newOperand);
2209 AttachOp::create(rewriter, op->getLoc(), newOperands);
// The old attach (the wire's only use) is erased before the wire itself.
2210 rewriter.eraseOp(op);
2211 rewriter.eraseOp(wire);
// Precondition: `region` holds exactly one block.
2222 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
// Inline that block's contents immediately before `op`; the address of the
// region's first (and only) block is passed, with no block-argument values.
2223 rewriter.inlineBlockBefore(&region.front(), op, {});
// Canonicalizes a WhenOp. The visible lines cover:
//  - a constant condition, where the op is erased after a branch is selected
//    based on the constant and on the presence of a non-empty else region;
//  - an empty else block next to a non-empty then block, where the else block
//    is erased;
//  - an empty then block with no (or an empty) else block, where the whole op
//    is erased.
// NOTE(review): the statements executed for the selected branch and the
// return statements are not visible in this excerpt.
2226LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
2227 if (
auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
// All-ones constant versus the remaining case with a usable else region.
2228 if (constant.getValue().isAllOnes())
2230 else if (op.hasElseRegion() && !op.getElseRegion().empty())
2233 rewriter.eraseOp(op);
// Remove an else block that contains nothing.
2239 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
2240 op.getElseBlock().empty()) {
2241 rewriter.eraseBlock(&op.getElseBlock());
// From here on only an empty then block is of interest.
2248 if (!op.getThenBlock().empty())
2252 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
2253 rewriter.eraseOp(op);
// Rewrite pattern anchored on NodeOp. The visible lines replace a node with
// its input when the node's name is droppable and it carries no inner symbol
// (further conditions follow in lines not visible here).
// NOTE(review): how `name` and `newOp` are used between the visible lines is
// not shown in this excerpt.
2262struct FoldNodeName :
public mlir::RewritePattern {
2263 FoldNodeName(MLIRContext *
context)
2264 : RewritePattern(NodeOp::getOperationName(), 0,
context) {}
2265 LogicalResult matchAndRewrite(Operation *op,
2266 PatternRewriter &rewriter)
const override {
2267 auto node = cast<NodeOp>(op);
2268 auto name = node.getNameAttr();
// Precondition check on the node; the rest of the condition is not visible.
2269 if (!node.hasDroppableName() || node.getInnerSym() ||
// Operation defining the node's input (null if there is none).
2272 auto *newOp = node.getInput().getDefiningOp();
2275 rewriter.replaceOp(node, node.getInput());
// Rewrite pattern anchored on NodeOp. The visible lines redirect every use of
// the node's result to the node's input, leaving the node op itself in place.
// The match condition involves the node having uses and not being forceable
// (the start of the condition is not visible in this excerpt).
2281struct NodeBypass :
public mlir::RewritePattern {
2282 NodeBypass(MLIRContext *
context)
2283 : RewritePattern(NodeOp::getOperationName(), 0,
context) {}
2284 LogicalResult matchAndRewrite(Operation *op,
2285 PatternRewriter &rewriter)
const override {
2286 auto node = cast<NodeOp>(op);
2288 node.use_empty() || node.isForceable())
2290 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
2297template <
typename OpTy>
2299 PatternRewriter &rewriter) {
2300 if (!op.isForceable() || !op.getDataRef().use_empty())
// Folder for NodeOp. The visible lines push the folded input attribute into
// `results`, after a check on whether that attribute is present.
// NOTE(review): additional guards and the return statements are not visible
// in this excerpt.
2308LogicalResult NodeOp::fold(FoldAdaptor adaptor,
2309 SmallVectorImpl<OpFoldResult> &results) {
// A null attribute means the input has no folded (constant) value.
2318 if (!adaptor.getInput())
2321 results.push_back(adaptor.getInput());
// Registers the canonicalization patterns for NodeOp: the FoldNodeName
// pattern and the demoteForceableIfUnused<NodeOp> function pattern.
2325void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2327 results.insert<FoldNodeName>(
context);
2328 results.add(demoteForceableIfUnused<NodeOp>);
// Base rewrite pattern that turns a set of per-element matching connects to a
// bundle/vector value into one connect of a freshly created aggregate.
// The operation name to anchor on is supplied by the derived pattern.
// NOTE(review): many failure paths (the bodies of the guarding `if`s) are not
// visible in this excerpt; the comments below describe only the visible code.
2334struct AggOneShot :
public mlir::RewritePattern {
// `weight` is accepted but the visible base-class call passes a literal 0.
2335 AggOneShot(StringRef name, uint32_t weight, MLIRContext *
context)
2336 : RewritePattern(name, 0,
context) {}
// Collects, for result 0 of `lhs`, the source value connected to each element
// (bundle field or vector index). Returns one value per element in element
// order; what is returned on failure is not visible here.
2338 SmallVector<Value> getCompleteWrite(Operation *lhs)
const {
2339 auto lhsTy = lhs->getResult(0).getType();
// Only bundle and vector types are handled.
2340 if (!type_isa<BundleType, FVectorType>(lhsTy))
// Maps element index -> value written to that element.
2343 DenseMap<uint32_t, Value> fields;
2344 for (Operation *user : lhs->getResult(0).getUsers()) {
// Users are checked for having the same parent op as `lhs`.
2345 if (user->getParentOp() != lhs->getParentOp())
2347 if (
auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
// A connect whose destination is the whole aggregate.
2348 if (aConnect.getDest() == lhs->getResult(0))
2350 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
// Bundle case: look for connects that write this subfield.
2351 for (Operation *subuser : subField.getResult().getUsers()) {
2352 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2353 if (aConnect.getDest() == subField) {
2354 if (subuser->getParentOp() != lhs->getParentOp())
// A field that already has a recorded write is detected here.
2356 if (fields.count(subField.getFieldIndex()))
2358 fields[subField.getFieldIndex()] = aConnect.getSrc();
2364 }
else if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
// Vector case: same scheme, keyed by the subindex's element index.
2365 for (Operation *subuser : subIndex.getResult().getUsers()) {
2366 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2367 if (aConnect.getDest() == subIndex) {
2368 if (subuser->getParentOp() != lhs->getParentOp())
2370 if (fields.count(subIndex.getIndex()))
2372 fields[subIndex.getIndex()] = aConnect.getSrc();
// Emit the collected values in element order; a missing element is detected
// by the count() check inside the loop.
2383 SmallVector<Value> values;
2384 uint32_t total = type_isa<BundleType>(lhsTy)
2385 ? type_cast<BundleType>(lhsTy).getNumElements()
2386 : type_cast<FVectorType>(lhsTy).getNumElements();
2387 for (uint32_t i = 0; i < total; ++i) {
2388 if (!fields.count(i))
2390 values.push_back(fields[i]);
// Builds a bundle/vector create op from the collected values, connects it to
// result 0 of `op` at the end of the block, and erases the per-element
// connects it replaces.
2395 LogicalResult matchAndRewrite(Operation *op,
2396 PatternRewriter &rewriter)
const override {
2397 auto values = getCompleteWrite(op);
// New ops are appended at the end of the block containing `op`.
2400 rewriter.setInsertionPointToEnd(op->getBlock());
2401 auto dest = op->getResult(0);
2402 auto destType = dest.getType();
// Non-passive destination types are checked for here.
2405 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2408 Value newVal = type_isa<BundleType>(destType)
2409 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2411 : rewriter.createOrFold<VectorCreateOp>(
2412 op->
getLoc(), destType, values);
2413 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
// Remove the now-redundant element-wise connects. make_early_inc_range is
// used because connects are erased while their use list is being walked.
2414 for (Operation *user : dest.getUsers()) {
2415 if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2416 for (Operation *subuser :
2417 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2418 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2419 if (aConnect.getDest() == subIndex)
2420 rewriter.eraseOp(aConnect);
2421 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2422 for (Operation *subuser :
2423 llvm::make_early_inc_range(subField.getResult().getUsers()))
2424 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2425 if (aConnect.getDest() == subField)
2426 rewriter.eraseOp(aConnect);
// AggOneShot specialization anchored on WireOp.
2433struct WireAggOneShot :
public AggOneShot {
2434 WireAggOneShot(MLIRContext *
context)
2435 : AggOneShot(WireOp::getOperationName(), 0,
context) {}
// AggOneShot specialization anchored on SubindexOp.
2437struct SubindexAggOneShot :
public AggOneShot {
2438 SubindexAggOneShot(MLIRContext *
context)
2439 : AggOneShot(SubindexOp::getOperationName(), 0,
context) {}
// AggOneShot specialization anchored on SubfieldOp.
2441struct SubfieldAggOneShot :
public AggOneShot {
2442 SubfieldAggOneShot(MLIRContext *
context)
2443 : AggOneShot(SubfieldOp::getOperationName(), 0,
context) {}
// Registers the canonicalization patterns for WireOp: WireAggOneShot and the
// demoteForceableIfUnused<WireOp> function pattern.
2447void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2449 results.insert<WireAggOneShot>(
context);
2450 results.add(demoteForceableIfUnused<WireOp>);
// Registers the canonicalization patterns for SubindexOp: SubindexAggOneShot.
2453void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2455 results.insert<SubindexAggOneShot>(
context);
// Folder for SubindexOp: when the input folded to an ArrayAttr, the element
// at this op's index is returned.
// NOTE(review): the null-check on `attr` is not visible in this excerpt.
2458OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
// dyn_cast_or_null: the adaptor yields a null attribute for non-constants.
2459 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2462 return attr[getIndex()];
// Folder for SubfieldOp: reads the input as an ArrayAttr (if it folded to
// one) and looks up this op's field index.
// NOTE(review): the null-check on `attr` and the return statement are not
// visible in this excerpt.
2465OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2466 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2469 auto index = getFieldIndex();
// Registers the canonicalization patterns for SubfieldOp: SubfieldAggOneShot.
2473void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2475 results.insert<SubfieldAggOneShot>(
context);
2479 ArrayRef<Attribute> operands) {
2480 for (
auto operand : operands)
2483 return ArrayAttr::get(
context, operands);
// Folder for BundleCreateOp. The visible lines recognize a bundle that is
// rebuilt field-by-field from one existing bundle of the same type, i.e.
// bundle_create(subfield(x, 0), subfield(x, 1), ...), and fold it to x.
// NOTE(review): the constant-folding path and the final return are not
// visible in this excerpt.
2486OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2489 if (getNumOperands() > 0)
2490 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
// The first operand must be field 0 of a value with exactly this type.
2491 if (first.getFieldIndex() == 0 &&
2492 first.getInput().getType() == getType() &&
// Every remaining operand must be the field, at its own position, of the
// same input.
2494 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2496 elem.value().
template getDefiningOp<SubfieldOp>();
2497 return subindex && subindex.getInput() == first.getInput() &&
2498 subindex.getFieldIndex() == elem.index();
2500 return first.getInput();
// Folder for VectorCreateOp. The visible lines recognize a vector that is
// rebuilt element-by-element from one existing vector of the same type, i.e.
// vector_create(subindex(x, 0), subindex(x, 1), ...), and fold it to x.
// NOTE(review): the constant-folding path and the final return are not
// visible in this excerpt.
2505OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2508 if (getNumOperands() > 0)
2509 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
// The first operand must be element 0 of a value with exactly this type.
2510 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
// Every remaining operand must be the element, at its own position, of the
// same input.
2512 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2514 elem.value().
template getDefiningOp<SubindexOp>();
2515 return subindex && subindex.getInput() == first.getInput() &&
2516 subindex.getIndex() == elem.index();
2518 return first.getInput();
// Folder for UninferredResetCastOp: a cast whose operand already has the
// result type folds to its operand.
// NOTE(review): the fall-through return is not visible in this excerpt.
2523OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2524 if (getOperand().getType() == getType())
2525 return getOperand();
// Rewrite pattern anchored on RegResetOp. The visible lines look for a
// register whose reset value is a constant and whose driver (`con`) is a mux
// between the register itself and a constant equal to that reset value; the
// driving connect is then erased.
// NOTE(review): the definitions of `reset` and `con`, the failure returns and
// the replacement emitted before the erase are not visible in this excerpt.
2532struct FoldResetMux :
public mlir::RewritePattern {
2533 FoldResetMux(MLIRContext *
context)
2534 : RewritePattern(RegResetOp::getOperationName(), 0,
context) {}
2535 LogicalResult matchAndRewrite(Operation *op,
2536 PatternRewriter &rewriter)
const override {
2537 auto reg = cast<RegResetOp>(op);
// The reset value must be defined by a ConstantOp (null otherwise).
2539 dyn_cast_or_null<ConstantOp>(
reg.getResetValue().getDefiningOp());
// The connect's source must be a mux.
2548 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2551 auto *high = mux.getHigh().getDefiningOp();
2552 auto *low = mux.getLow().getDefiningOp();
// Try the "constant on the high side" arrangement first.
2553 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2555 if (constOp && low != reg)
// Otherwise accept "register on the high side, constant on the low side".
2557 if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
2558 constOp = dyn_cast<ConstantOp>(low);
// The mux constant must match the reset constant in both type and value.
2560 if (!constOp || constOp.getType() != reset.getType() ||
2561 constOp.getValue() != reset.getValue())
// All involved values must share the register's type, which must have a
// non-negative bit width.
2565 auto regTy =
reg.getResult().getType();
2566 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2567 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2568 regTy.getBitWidthOrSentinel() < 0)
// Hoist the constant to the front of the connect's block so it is defined
// ahead of every other op in that block.
2574 if (constOp != &con->getBlock()->front())
2575 constOp->moveBefore(&con->getBlock()->front());
2580 rewriter.eraseOp(con);
2587 if (
auto c = v.getDefiningOp<ConstantOp>())
2588 return c.getValue().isOne();
2589 if (
auto sc = v.getDefiningOp<SpecialConstantOp>())
2590 return sc.getValue();
2599 auto resetValue = reg.getResetValue();
2600 if (reg.getType(0) != resetValue.getType())
2604 (void)
dropWrite(rewriter, reg->getResult(0), {});
2605 replaceOpWithNewOpAndCopyName<NodeOp>(
2606 rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
2607 reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
// Registers the canonicalization patterns for RegResetOp:
// patterns::RegResetWithZeroReset, FoldResetMux, and the
// demoteForceableIfUnused<RegResetOp> function pattern.
// NOTE(review): other registrations may exist in lines not visible here.
2611void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2613 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(
context);
2615 results.add(demoteForceableIfUnused<RegResetOp>);
2620 auto portTy = type_cast<BundleType>(port.getType());
2621 auto fieldIndex = portTy.getElementIndex(name);
2622 assert(fieldIndex &&
"missing field on memory port");
2625 for (
auto *op : port.getUsers()) {
2626 auto portAccess = cast<SubfieldOp>(op);
2627 if (fieldIndex != portAccess.getFieldIndex())
2632 value = conn.getSrc();
2642 auto portConst = value.getDefiningOp<ConstantOp>();
2645 return portConst.getValue().isZero();
2650 auto portTy = type_cast<BundleType>(port.getType());
2651 auto fieldIndex = portTy.getElementIndex(
data);
2652 assert(fieldIndex &&
"missing enable flag on memory port");
2654 for (
auto *op : port.getUsers()) {
2655 auto portAccess = cast<SubfieldOp>(op);
2656 if (fieldIndex != portAccess.getFieldIndex())
2658 if (!portAccess.use_empty())
2667 StringRef name, Value value) {
2668 auto portTy = type_cast<BundleType>(port.getType());
2669 auto fieldIndex = portTy.getElementIndex(name);
2670 assert(fieldIndex &&
"missing field on memory port");
2672 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2673 auto portAccess = cast<SubfieldOp>(op);
2674 if (fieldIndex != portAccess.getFieldIndex())
2676 rewriter.replaceAllUsesWith(portAccess, value);
2677 rewriter.eraseOp(portAccess);
// Removes all uses of a memory port value. Connects that write to the port's
// subfields are erased, unused subfield accesses are erased, and subfield
// accesses that are still read are replaced by a register. On exit the port
// has no remaining uses (asserted).
// NOTE(review): several conditions and early exits are not visible in this
// excerpt.
2682static void erasePort(PatternRewriter &rewriter, Value port) {
// Helper that provides the clock value for the replacement registers; the
// visible part creates a SpecialConstantOp of clock type.
2685 auto getClock = [&] {
2687 clock = SpecialConstantOp::create(rewriter, port.getLoc(),
2688 ClockType::get(rewriter.getContext()),
// First pass: inspect the users of the port. In the visible lines the whole
// port is replaced by a register of the port's type (presumably when a user
// is not a subfield access -- TODO confirm, the condition is not visible).
2697 for (
auto *op : port.getUsers()) {
2698 auto subfield = dyn_cast<SubfieldOp>(op);
2700 auto ty = port.getType();
2701 auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
2702 rewriter.replaceAllUsesWith(port, reg.getResult());
// Second pass: every remaining user is cast to a subfield access.
// make_early_inc_range is used because users are erased during iteration.
2711 for (
auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2712 auto access = cast<SubfieldOp>(accessOp);
// Drop connects that write into this subfield.
2713 for (
auto *user : llvm::make_early_inc_range(access->getUsers())) {
2714 auto connect = dyn_cast<FConnectLike>(user);
2715 if (connect && connect.getDest() == access) {
2716 rewriter.eraseOp(user);
// Nothing reads the subfield any more: remove it.
2720 if (access.use_empty()) {
2721 rewriter.eraseOp(access);
// The subfield still has readers: substitute a register of the same type.
2727 auto ty = access.getType();
2728 auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
2729 rewriter.replaceOp(access, reg.getResult());
2731 assert(port.use_empty() &&
"port should have no remaining uses");
// Rewrite pattern anchored on MemOp for memories whose data type is an
// integer type of bit width 0. Every subfield access of every port is
// replaced by a wire, wires for fields whose name ends with "data" are driven
// with a zero-width zero constant, and the memory is erased.
// NOTE(review): the failure returns are not visible in this excerpt.
2736struct FoldZeroWidthMemory :
public mlir::RewritePattern {
2737 FoldZeroWidthMemory(MLIRContext *
context)
2738 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
2739 LogicalResult matchAndRewrite(Operation *op,
2740 PatternRewriter &rewriter)
const override {
2741 MemOp mem = cast<MemOp>(op);
// Only integer data types with a bit width of exactly 0 qualify.
2745 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2746 mem.getDataType().getBitWidthOrSentinel() != 0)
// Check that every user of every port is a SubfieldOp; the rewrite below
// relies on that when it casts the users.
2750 for (
auto port : mem.getResults())
2751 for (auto *user : port.getUsers())
2752 if (!isa<SubfieldOp>(user))
// Replace each subfield access with a wire of the same type.
2757 for (
auto port : op->getResults()) {
2758 for (
auto *user :
llvm::make_early_inc_range(port.getUsers())) {
2759 SubfieldOp sfop = cast<SubfieldOp>(user);
2760 StringRef fieldName = sfop.getFieldName();
2761 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2762 rewriter, sfop, sfop.getResult().getType())
// Matches any field whose name ends with "data".
2764 if (fieldName.ends_with(
"data")) {
// Drive the wire with a zero-width zero constant.
2766 auto zero = firrtl::ConstantOp::create(
2767 rewriter, wire.getLoc(),
2768 firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
2769 MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
2773 rewriter.eraseOp(op);
// Rewrite pattern anchored on MemOp. The visible lines classify the memory's
// ports by kind into "read" and "written", assert that the two flags are not
// both set, special-case a read memory with an initializer, and erase the
// memory.
// NOTE(review): the per-case bodies, the per-port handling inside the final
// loop and the returns are not visible in this excerpt.
2779struct FoldReadOrWriteOnlyMemory :
public mlir::RewritePattern {
2780 FoldReadOrWriteOnlyMemory(MLIRContext *
context)
2781 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
2782 LogicalResult matchAndRewrite(Operation *op,
2783 PatternRewriter &rewriter)
const override {
2784 MemOp mem = cast<MemOp>(op);
2787 bool isRead =
false, isWritten =
false;
// Walk all ports and dispatch on the port kind.
2788 for (
unsigned i = 0; i < mem.getNumResults(); ++i) {
2789 switch (mem.getPortKind(i)) {
2790 case MemOp::PortKind::Read:
2795 case MemOp::PortKind::Write:
// Debug and ReadWrite ports share one case body.
2800 case MemOp::PortKind::Debug:
2801 case MemOp::PortKind::ReadWrite:
2804 llvm_unreachable(
"unknown port kind");
2806 assert((!isWritten || !isRead) &&
"memory is in use");
// A memory that is read and has an initializer is singled out.
2811 if (isRead && mem.getInit())
2814 for (
auto port : mem.getResults())
2817 rewriter.eraseOp(op);
// Rewrite pattern anchored on MemOp that removes dead ports. Dead ports are
// recorded in a bit vector; if any exist, a new MemOp is created from the
// remaining ports (when at least one remains), uses of the kept ports are
// redirected to the new op's results, and the old op is erased.
// NOTE(review): how bits in `deadPorts` are set, the handling of dead ports
// in the last loop, and the returns are not visible in this excerpt.
2823struct FoldUnusedPorts :
public mlir::RewritePattern {
2824 FoldUnusedPorts(MLIRContext *
context)
2825 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
2826 LogicalResult matchAndRewrite(Operation *op,
2827 PatternRewriter &rewriter)
const override {
2828 MemOp mem = cast<MemOp>(op);
// One bit per port.
2832 llvm::SmallBitVector deadPorts(mem.getNumResults());
2833 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
// Ports carrying annotations are tested for here.
2835 if (!mem.getPortAnnotation(i).empty())
// Debug ports are tested for here.
2839 auto kind = mem.getPortKind(i);
2840 if (kind == MemOp::PortKind::Debug)
// A read port whose "data" field is unused.
2849 if (kind == MemOp::PortKind::Read &&
isPortUnused(port,
"data")) {
// No dead ports found.
2854 if (deadPorts.none())
// Gather type, name and annotation of each port for the new memory.
2858 SmallVector<Type> resultTypes;
2859 SmallVector<StringRef> portNames;
2860 SmallVector<Attribute> portAnnotations;
2861 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2864 resultTypes.push_back(port.getType());
2865 portNames.push_back(mem.getPortName(i));
2866 portAnnotations.push_back(mem.getPortAnnotation(i));
// A replacement memory is only created when at least one port remains; all
// other properties are copied from the original.
2870 if (!resultTypes.empty())
2871 newOp = MemOp::create(
2872 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
2873 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2874 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2875 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2876 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
// `nextPort` numbers the results of the new memory as ports are remapped.
2879 unsigned nextPort = 0;
2880 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2884 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
2887 rewriter.eraseOp(op);
// Rewrite pattern anchored on MemOp that converts read-write ports recorded
// in `deadReads` into write-only ports. A new MemOp is created with adjusted
// port types, the subfield accesses of converted ports are remapped onto the
// write port's fields, and the old op is erased.
// NOTE(review): how bits in `deadReads` are set, the per-port conditions in
// the loops, and the returns are not visible in this excerpt.
2893struct FoldReadWritePorts :
public mlir::RewritePattern {
2894 FoldReadWritePorts(MLIRContext *
context)
2895 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
2896 LogicalResult matchAndRewrite(Operation *op,
2897 PatternRewriter &rewriter)
const override {
2898 MemOp mem = cast<MemOp>(op);
// One bit per port.
2903 llvm::SmallBitVector deadReads(mem.getNumResults());
2904 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
// Only ReadWrite ports are candidates; annotated ports are tested for too.
2905 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
2907 if (!mem.getPortAnnotation(i).empty())
// Nothing to convert.
2914 if (deadReads.none())
2917 SmallVector<Type> resultTypes;
2918 SmallVector<StringRef> portNames;
2919 SmallVector<Attribute> portAnnotations;
2920 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
// Either a Write port type built from the memory's depth, data type and mask
// bits, or the port's existing type, is pushed (the selecting condition is
// not visible).
2922 resultTypes.push_back(
2923 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
2924 MemOp::PortKind::Write, mem.getMaskBits()));
2926 resultTypes.push_back(port.getType());
2928 portNames.push_back(mem.getPortName(i));
2929 portAnnotations.push_back(mem.getPortAnnotation(i));
// The new memory copies every other property of the original.
2932 auto newOp = MemOp::create(
2933 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
2934 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2935 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2936 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2937 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
// Port count and order are unchanged, so result i maps to new result i.
2939 for (
unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
2940 auto result = mem.getResult(i);
2941 auto newResult = newOp.getResult(i);
2943 auto resultPortTy = type_cast<BundleType>(result.getType());
// Replaces every access of field `fromName` on the old port with one access
// of a field on the new port (`toName` is presumably the field used for the
// new access -- its use is not visible here).
2947 auto replace = [&](StringRef toName, StringRef fromName) {
2948 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
2949 assert(fromFieldIndex &&
"missing enable flag on memory port");
2951 auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
2953 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
2954 auto fromField = cast<SubfieldOp>(op);
2955 if (fromFieldIndex != fromField.getFieldIndex())
2957 rewriter.replaceOp(fromField, toField.getResult());
// Field mapping as (new-port field, old-port field).
2961 replace(
"addr",
"addr");
2962 replace(
"en",
"en");
2963 replace(
"clk",
"clk");
2964 replace(
"data",
"wdata");
2965 replace(
"mask",
"wmask");
// Accesses of "wmode" have no counterpart in the mapping above; each one is
// replaced by a wire of the same type.
2968 auto wmodeFieldIndex = resultPortTy.getElementIndex(
"wmode");
2969 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
2970 auto wmodeField = cast<SubfieldOp>(op);
2971 if (wmodeFieldIndex != wmodeField.getFieldIndex())
2973 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
// Ports redirected here keep their type (presumably the non-converted ones).
2976 rewriter.replaceAllUsesWith(result, newResult);
2979 rewriter.eraseOp(op);
// Rewrite pattern anchored on MemOp that narrows the memory's data type to
// the bits that are actually read. Reads must go through BitsPrimOp users of
// the read-data field; the bit ranges they select are recorded, compacted
// into a new, narrower integer type, and the memory, its read extractions and
// its write connects are rewritten accordingly.
// NOTE(review): a number of guards, failure returns and loop bodies are not
// visible in this excerpt; the comments describe only the visible code.
2985struct FoldUnusedBits :
public mlir::RewritePattern {
2986 FoldUnusedBits(MLIRContext *
context)
2987 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
2989 LogicalResult matchAndRewrite(Operation *op,
2990 PatternRewriter &rewriter)
const override {
2991 MemOp mem = cast<MemOp>(op);
// Masked memories and memories classified as "seq mem" by the summary are
// tested for here.
2996 const auto &summary = mem.getSummary();
2997 if (summary.isMasked || summary.isSeqMem())
// Only integer data types are handled (type_dyn_cast yields null otherwise).
3000 auto type = type_dyn_cast<IntType>(mem.getDataType());
3003 auto width = type.getBitWidthOrSentinel();
// usedBits: one flag per data bit. mapping: low bit of each read extraction
// -> low bit position in the narrowed type (filled in further below).
3007 llvm::SmallBitVector usedBits(width);
3008 DenseMap<unsigned, unsigned> mapping;
// Collects the BitsPrimOp readers of `field` on `port` and marks the bits
// they select as used.
3013 SmallVector<BitsPrimOp> readOps;
3014 auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
3015 auto portTy = type_cast<BundleType>(port.getType());
3016 auto fieldIndex = portTy.getElementIndex(field);
3017 assert(fieldIndex &&
"missing data port");
3019 for (
auto *op : port.getUsers()) {
3020 auto portAccess = cast<SubfieldOp>(op);
3021 if (fieldIndex != portAccess.getFieldIndex())
3024 for (
auto *user : op->getUsers()) {
3025 auto bits = dyn_cast<BitsPrimOp>(user);
// SmallBitVector::set takes a half-open range, hence getHi() + 1.
3029 usedBits.set(bits.getLo(), bits.getHi() + 1);
// Placeholder entry; the real offset is assigned during range compaction.
3033 mapping[bits.getLo()] = 0;
3034 readOps.push_back(bits);
// Collects the connects that write `field` on `port`.
3044 SmallVector<MatchingConnectOp> writeOps;
3045 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
3046 auto portTy = type_cast<BundleType>(port.getType());
3047 auto fieldIndex = portTy.getElementIndex(field);
3048 assert(fieldIndex &&
"missing data port");
3050 for (
auto *op : port.getUsers()) {
3051 auto portAccess = cast<SubfieldOp>(op);
3052 if (fieldIndex != portAccess.getFieldIndex())
3059 writeOps.push_back(conn);
// Scan every port, choosing the data field names by port kind.
3065 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
// Annotated ports are tested for here.
3067 if (!mem.getPortAnnotation(i).empty())
3070 switch (mem.getPortKind(i)) {
3071 case MemOp::PortKind::Debug:
3074 case MemOp::PortKind::Write:
3075 if (failed(findWriteUsers(port,
"data")))
3078 case MemOp::PortKind::Read:
3079 if (failed(findReadUsers(port,
"data")))
3082 case MemOp::PortKind::ReadWrite:
3083 if (failed(findWriteUsers(port,
"wdata")))
3085 if (failed(findReadUsers(port,
"rdata")))
3089 llvm_unreachable(
"unknown port kind");
// No bit is read at all.
3093 if (usedBits.none())
// Compact the used bits: walk each maximal run [i, e) of set bits, assign
// consecutive positions in the new type, and remember the original
// (start, end) of each run (end inclusive).
3097 SmallVector<std::pair<unsigned, unsigned>> ranges;
3098 unsigned newWidth = 0;
3099 for (
int i = usedBits.find_first(); 0 <= i && i < width;) {
3100 int e = usedBits.find_next_unset(i);
3103 for (
int idx = i; idx < e; ++idx, ++newWidth) {
// If a read extraction starts at this bit, record its new low position.
3104 if (
auto it = mapping.find(idx); it != mapping.end()) {
3105 it->second = newWidth;
3108 ranges.emplace_back(i, e - 1);
// Jump to the start of the next run, or stop when the end was reached.
3109 i = e != width ? usedBits.find_next(e) : e;
// Build the narrowed data type (same signedness) and matching port types.
3113 auto newType =
IntType::get(op->getContext(), type.isSigned(), newWidth);
3114 SmallVector<Type> portTypes;
3115 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3116 portTypes.push_back(
3117 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
// Replace the memory; all other properties are copied unchanged.
3119 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
3120 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
3121 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
3122 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
3123 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
// Re-creates the access of `field` on `port` right after the new memory and
// redirects the existing accesses of that field to it, so the access picks
// up the narrowed type.
3126 auto rewriteSubfield = [&](Value port, StringRef field) {
3127 auto portTy = type_cast<BundleType>(port.getType());
3128 auto fieldIndex = portTy.getElementIndex(field);
3129 assert(fieldIndex &&
"missing data port");
3131 rewriter.setInsertionPointAfter(newMem);
3132 auto newPortAccess =
3133 SubfieldOp::create(rewriter, port.getLoc(), port, field);
3135 for (
auto *op :
llvm::make_early_inc_range(port.getUsers())) {
3136 auto portAccess = cast<SubfieldOp>(op);
// Skip the freshly created access and accesses of other fields.
3137 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
3139 rewriter.replaceOp(portAccess, newPortAccess.getResult());
// Rewrite the data field accesses of every port of the new memory.
3144 for (
auto [i, port] :
llvm::enumerate(newMem.getResults())) {
3145 switch (newMem.getPortKind(i)) {
3146 case MemOp::PortKind::Debug:
3147 llvm_unreachable(
"cannot rewrite debug port");
3148 case MemOp::PortKind::Write:
3149 rewriteSubfield(port,
"data");
3151 case MemOp::PortKind::Read:
3152 rewriteSubfield(port,
"data");
3154 case MemOp::PortKind::ReadWrite:
3155 rewriteSubfield(port,
"rdata");
3156 rewriteSubfield(port,
"wdata");
3159 llvm_unreachable(
"unknown port kind");
// Re-issue each read extraction at its new bit position: the new low bit is
// the mapped offset and the extraction width (hi - lo) is preserved.
3163 for (
auto readOp : readOps) {
3164 rewriter.setInsertionPointAfter(readOp);
3165 auto it = mapping.find(readOp.getLo());
3166 assert(it != mapping.end() &&
"bit op mapping not found");
3169 auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
3170 readOp.getLoc(), readOp.getInput(),
3171 readOp.getHi() - readOp.getLo() + it->second, it->second);
3172 rewriter.replaceAllUsesWith(readOp, newReadValue);
3173 rewriter.eraseOp(readOp);
// Narrow each write: slice the used ranges out of the written value and
// concatenate them. The ranges are visited in reverse, i.e. highest range
// first.
3177 for (
auto writeOp : writeOps) {
3178 Value source = writeOp.getSrc();
3179 rewriter.setInsertionPoint(writeOp);
3181 SmallVector<Value> slices;
3182 for (
auto &[start, end] :
llvm::reverse(ranges)) {
3183 Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
3184 source,
end, start);
3185 slices.push_back(slice);
3189 rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);
// For signed data types the concatenation is converted with AsSIntPrimOp.
3195 if (type.isSigned())
3197 rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);
3199 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
// Rewrite pattern anchored on MemOp that replaces a memory with a register
// plus surrounding logic. The visible lines collect the port field accesses
// and the connects driving them, create a wire and a register for the stored
// value, pipeline port signals through registers, build the next register
// value from the collected writes (honouring the write mask), and finally
// erase the old connects, accesses and the memory.
// NOTE(review): the qualifying conditions (e.g. on depth or latency), the
// definitions of `info`, `clock`, `portClock` and several loop bodies are not
// visible in this excerpt.
3208struct FoldRegMems :
public mlir::RewritePattern {
3209 FoldRegMems(MLIRContext *
context)
3210 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
3211 LogicalResult matchAndRewrite(Operation *op,
3212 PatternRewriter &rewriter)
const override {
3213 MemOp mem = cast<MemOp>(op);
3218 auto ty = mem.getDataType();
3219 auto loc = mem.getLoc();
3220 auto *block = mem->getBlock();
// Connects and field accesses that become obsolete once the memory is
// replaced; they are erased at the end.
3224 SmallPtrSet<Operation *, 8> connects;
3225 SmallVector<SubfieldOp> portAccesses;
3226 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
// Annotated ports are tested for here.
3227 if (!mem.getPortAnnotation(i).empty())
// Records the accesses of the listed fields on this port together with the
// connect-like users of those accesses.
3230 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
3231 auto portTy = type_cast<BundleType>(port.getType());
3232 for (
auto field : fields) {
3233 auto fieldIndex = portTy.getElementIndex(field);
3234 assert(fieldIndex &&
"missing field on memory port");
3236 for (
auto *op : port.getUsers()) {
3237 auto portAccess = cast<SubfieldOp>(op);
3238 if (fieldIndex != portAccess.getFieldIndex())
3240 portAccesses.push_back(portAccess);
3241 for (
auto *user : portAccess->getUsers()) {
3242 auto conn = dyn_cast<FConnectLike>(user);
3245 connects.insert(conn);
// The set of input fields depends on the port kind.
3252 switch (mem.getPortKind(i)) {
3253 case MemOp::PortKind::Debug:
3255 case MemOp::PortKind::Read:
3256 if (failed(collect({
"clk",
"en",
"addr"})))
3259 case MemOp::PortKind::Write:
3260 if (failed(collect({
"clk",
"en",
"addr",
"data",
"mask"})))
3263 case MemOp::PortKind::ReadWrite:
3264 if (failed(collect({
"clk",
"en",
"addr",
"wmode",
"wdata",
"wmask"})))
// A port without a clock, or with a clock different from the one already
// seen, is detected here.
3270 if (!portClock || (clock && portClock != clock))
// A wire carrying the stored value is created right after the memory.
3276 rewriter.setInsertionPointAfter(mem);
3277 auto memWire = WireOp::create(rewriter, loc, ty).getResult();
// The remaining logic is emitted at the end of the block.
3283 rewriter.setInsertionPointToEnd(block);
// The register takes over the memory's name.
3285 RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();
// The wire is driven by the register.
3288 MatchingConnectOp::create(rewriter, loc, memWire, memReg);
// Delays `value` through a chain of `latency` registers clocked by `clock`;
// each register is named "<mem name>_<name>_<stage index>".
3292 auto pipeline = [&](Value value, Value clock,
const Twine &name,
3294 for (
unsigned i = 0; i < latency; ++i) {
3295 std::string regName;
3297 llvm::raw_string_ostream os(regName);
3298 os << mem.getName() <<
"_" << name <<
"_" << i;
3300 auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
3301 rewriter.getStringAttr(regName))
3303 MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
// Number of pipeline stages applied to write-side signals.
3309 const unsigned writeStages =
info.writeLatency - 1;
// Each entry is a (data, enable, mask) triple for one write.
3314 SmallVector<std::tuple<Value, Value, Value>> writes;
3315 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3317 StringRef name = mem.getPortName(i);
// Pipelines one field of this port, naming stages "<port name>_<field>".
3319 auto portPipeline = [&, port = port](StringRef field,
unsigned stages) {
3322 return pipeline(value, portClock, name +
"_" + field, stages);
3325 switch (mem.getPortKind(i)) {
3326 case MemOp::PortKind::Debug:
3327 llvm_unreachable(
"unknown port kind");
3328 case MemOp::PortKind::Read: {
3336 case MemOp::PortKind::Write: {
3337 auto data = portPipeline(
"data", writeStages);
3338 auto en = portPipeline(
"en", writeStages);
3339 auto mask = portPipeline(
"mask", writeStages);
3343 case MemOp::PortKind::ReadWrite: {
3348 auto wdata = portPipeline(
"wdata", writeStages);
3349 auto wmask = portPipeline(
"wmask", writeStages);
// A read-write port writes only when both `en` and `wmode` are set.
3354 auto wen = AndPrimOp::create(rewriter, port.getLoc(),
en,
wmode);
3356 pipeline(wen, portClock, name +
"_wen", writeStages);
3357 writes.emplace_back(
wdata, wenPipelined,
wmask);
// `next` accumulates the register's next value, starting from its current
// value.
3364 Value next = memReg;
3370 Location loc = mem.getLoc();
// Number of data bits governed by one mask bit.
3371 unsigned maskGran =
info.dataWidth /
info.maskBits;
3372 SmallVector<Value> chunks;
3373 for (
unsigned i = 0; i <
info.maskBits; ++i) {
3374 unsigned hi = (i + 1) * maskGran - 1;
3375 unsigned lo = i * maskGran;
// Per mask bit: take the new data chunk if the bit is set, otherwise keep
// the corresponding chunk of the current value.
3377 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc,
data, hi, lo);
3378 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3379 auto bit = rewriter.createOrFold<BitsPrimOp>(loc,
mask, i, i);
3380 auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
3381 chunks.push_back(chunk);
// Chunks were produced low-to-high; reverse so the highest chunk comes first
// in the concatenation.
3384 std::reverse(chunks.begin(), chunks.end());
3385 masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
// Apply the write only when it is enabled.
3386 next = MuxPrimOp::create(rewriter, next.getLoc(),
en, masked, next);
// Cast back to the memory's data type and drive the register.
3388 Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
3389 MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);
// Clean up: connects first, then the field accesses they used, then the
// memory itself.
3392 for (Operation *conn : connects)
3393 rewriter.eraseOp(conn);
3394 for (
auto portAccess : portAccesses)
3395 rewriter.eraseOp(portAccess);
3396 rewriter.eraseOp(mem);
3403void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3406 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3407 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3427 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3430 auto *high = mux.getHigh().getDefiningOp();
3431 auto *low = mux.getLow().getDefiningOp();
3433 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3440 bool constReg =
false;
3442 if (constOp && low == reg)
3444 else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
3446 constOp = dyn_cast<ConstantOp>(low);
3453 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3457 auto regTy = reg.getResult().getType();
3458 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3459 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3460 regTy.getBitWidthOrSentinel() < 0)
3466 if (constOp != &con->getBlock()->front())
3467 constOp->moveBefore(&con->getBlock()->front());
3470 SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
3471 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3472 rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
3473 mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
3474 reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
3475 reg.getForceableAttr());
3476 newReg->setDialectAttrs(attrs);
3478 auto pt = rewriter.saveInsertionPoint();
3479 rewriter.setInsertionPoint(con);
3480 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3481 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3482 rewriter.restoreInsertionPoint(pt);
3486LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3487 if (!
hasDontTouch(op.getOperation()) && !op.isForceable() &&
3503 PatternRewriter &rewriter,
3506 if (
auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3507 if (constant.getValue().isZero()) {
3508 rewriter.eraseOp(op);
3514 if (
auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3515 if (constant.getValue().isZero() == eraseIfZero) {
3516 rewriter.eraseOp(op);
3524template <
class Op,
bool EraseIfZero = false>
3526 PatternRewriter &rewriter) {
3531void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3533 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3534 results.add<patterns::AssertXWhenX>(
context);
3537void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3539 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3540 results.add<patterns::AssumeXWhenX>(
context);
3543void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3544 RewritePatternSet &results, MLIRContext *
context) {
3545 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3546 results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(
context);
3549void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3551 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3558LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3559 PatternRewriter &rewriter) {
3561 if (op.use_empty()) {
3562 rewriter.eraseOp(op);
3569 if (op->hasOneUse() &&
3570 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3571 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3572 *op->user_begin()) ||
3573 (isa<CvtPrimOp>(*op->user_begin()) &&
3574 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3575 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3576 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3577 .getBitWidthOrSentinel() > 0))) {
3578 auto *modop = *op->user_begin();
3579 auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
3580 modop->getResult(0).getType());
3581 rewriter.replaceAllOpUsesWith(modop, inv);
3582 rewriter.eraseOp(modop);
3583 rewriter.eraseOp(op);
3589OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3590 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3591 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3599OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3608 return BoolAttr::get(getContext(),
false);
3612 return BoolAttr::get(getContext(),
false);
3617LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3618 PatternRewriter &rewriter) {
3620 if (
auto testEnable = op.getTestEnable()) {
3621 if (
auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3622 if (constOp.getValue().isZero()) {
3623 rewriter.modifyOpInPlace(op,
3624 [&] { op.getTestEnableMutable().clear(); });
3640 auto forceable = op.getRef().getDefiningOp<Forceable>();
3641 if (!forceable || !forceable.isForceable() ||
3642 op.getRef() != forceable.getDataRef() ||
3643 op.getType() != forceable.getDataType())
3645 rewriter.replaceAllUsesWith(op, forceable.getData());
3649void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3651 results.insert<patterns::RefResolveOfRefSend>(
context);
3655OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3657 if (getInput().getType() == getType())
3663 auto constOp = operand.getDefiningOp<ConstantOp>();
3664 return constOp && constOp.getValue().isZero();
3667template <
typename Op>
3670 rewriter.eraseOp(op);
3676void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3678 results.add(eraseIfPredFalse<RefForceOp>);
3680void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3682 results.add(eraseIfPredFalse<RefForceInitialOp>);
3684void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3686 results.add(eraseIfPredFalse<RefReleaseOp>);
3688void RefReleaseInitialOp::getCanonicalizationPatterns(
3689 RewritePatternSet &results, MLIRContext *
context) {
3690 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3697OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3703 if (adaptor.getReset())
3708 if (
isUInt1(getReset().getType()) && adaptor.getClock())
3721 [&](
auto ty) ->
bool {
return isTypeEmpty(ty.getElementType()); })
3722 .Case<BundleType>([&](
auto ty) ->
bool {
3723 for (
auto elem : ty.getElements())
3728 .Case<IntType>([&](
auto ty) {
return ty.getWidth() == 0; })
3729 .Default([](
auto) ->
bool {
return false; });
3732LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3733 PatternRewriter &rewriter) {
3734 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3741 rewriter.eraseOp(op);
3749LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3750 PatternRewriter &rewriter) {
3753 if (op.getBody()->empty()) {
3754 rewriter.eraseOp(op);
3765OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3767 if (getDomains().
empty())
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user that strictly dominates "dominatedAttach", then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this type is a 1-bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width APInts.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
BitsOfCat(MLIRContext *context)