20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/TypeSwitch.h"
28using namespace firrtl;
32static Value
dropWrite(PatternRewriter &rewriter, OpResult old,
34 SmallPtrSet<Operation *, 8> users;
35 for (
auto *user : old.getUsers())
37 for (Operation *user : users)
38 if (
auto connect = dyn_cast<FConnectLike>(user))
39 if (connect.getDest() == old)
40 rewriter.eraseOp(user);
50 if (op->getNumRegions() != 0)
52 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
60 Operation *op = passthrough.getDefiningOp();
63 assert(op &&
"passthrough must be an operation");
64 Operation *oldOp = old.getOwner();
65 auto name = oldOp->getAttrOfType<StringAttr>(
"name");
67 op->setAttr(
"name", name);
75#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
83 auto resultType = type_cast<IntType>(op->getResult(0).getType());
84 if (!resultType.hasWidth())
86 for (Value operand : op->getOperands())
87 if (!type_cast<IntType>(operand.getType()).hasWidth())
94 auto t = type_dyn_cast<UIntType>(type);
95 if (!t || !t.hasWidth() || t.getWidth() != 1)
102static void updateName(PatternRewriter &rewriter, Operation *op,
107 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) &&
"Should never rename");
108 auto newName = name.getValue();
109 auto newOpName = op->getAttrOfType<StringAttr>(
"name");
112 newName =
chooseName(newOpName.getValue(), name.getValue());
114 if (!newOpName || newOpName.getValue() != newName)
115 rewriter.modifyOpInPlace(
116 op, [&] { op->setAttr(
"name", rewriter.getStringAttr(newName)); });
124 if (
auto *newOp = newValue.getDefiningOp()) {
125 auto name = op->getAttrOfType<StringAttr>(
"name");
128 rewriter.replaceOp(op, newValue);
134template <
typename OpTy,
typename... Args>
136 Operation *op, Args &&...args) {
137 auto name = op->getAttrOfType<StringAttr>(
"name");
139 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
147 if (
auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
148 return namableOp.hasDroppableName();
159static std::optional<APSInt>
161 assert(type_cast<IntType>(operand.getType()) &&
162 "getExtendedConstant is limited to integer types");
169 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
174 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
175 return APSInt(destWidth,
176 type_cast<IntType>(operand.getType()).isUnsigned());
184 if (
auto attr = dyn_cast<BoolAttr>(operand))
185 return APSInt(APInt(1, attr.getValue()));
186 if (
auto attr = dyn_cast<IntegerAttr>(operand))
187 return attr.getAPSInt();
195 return cst->isZero();
222 Operation *op, ArrayRef<Attribute> operands,
BinOpKind opKind,
223 const function_ref<APInt(
const APSInt &,
const APSInt &)> &calculate) {
224 assert(operands.size() == 2 &&
"binary op takes two operands");
227 auto resultType = type_cast<IntType>(op->getResult(0).getType());
228 if (resultType.getWidthOrSentinel() < 0)
232 if (resultType.getWidthOrSentinel() == 0)
233 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
239 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
241 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
242 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
243 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
244 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
245 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
249 int32_t operandWidth;
252 operandWidth = resultType.getWidthOrSentinel();
257 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
261 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
272 APInt resultValue = calculate(*lhs, *rhs);
277 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
279 assert((
unsigned)resultType.getWidthOrSentinel() ==
280 resultValue.getBitWidth());
293 Operation *op, PatternRewriter &rewriter,
294 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
296 if (op->getNumResults() != 1)
298 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
303 auto width = type.getBitWidthOrSentinel();
308 SmallVector<Attribute, 3> constOperands;
309 constOperands.reserve(op->getNumOperands());
310 for (
auto operand : op->getOperands()) {
312 if (
auto *defOp = operand.getDefiningOp())
313 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
314 [&](
auto op) { attr = op.getValueAttr(); });
315 constOperands.push_back(attr);
320 auto result = canonicalize(constOperands);
324 if (
auto cst = dyn_cast<Attribute>(result))
325 resultValue = op->getDialect()
326 ->materializeConstant(rewriter, cst, type, op->getLoc())
329 resultValue = cast<Value>(result);
333 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
334 resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);
337 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
338 resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
339 else if (type_isa<UIntType>(type) &&
340 type_isa<SIntType>(resultValue.getType()))
341 resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);
343 assert(type == resultValue.getType() &&
"canonicalization changed type");
351 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
357 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
363 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
370OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
371 assert(adaptor.getOperands().empty() &&
"constant has no operands");
372 return getValueAttr();
375OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
376 assert(adaptor.getOperands().empty() &&
"constant has no operands");
377 return getValueAttr();
380OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
381 assert(adaptor.getOperands().empty() &&
"constant has no operands");
382 return getFieldsAttr();
385OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
386 assert(adaptor.getOperands().empty() &&
"constant has no operands");
387 return getValueAttr();
390OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
391 assert(adaptor.getOperands().empty() &&
"constant has no operands");
392 return getValueAttr();
395OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
396 assert(adaptor.getOperands().empty() &&
"constant has no operands");
397 return getValueAttr();
400OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
401 assert(adaptor.getOperands().empty() &&
"constant has no operands");
402 return getValueAttr();
409OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
412 [=](
const APSInt &a,
const APSInt &b) { return a + b; });
415void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
418 patterns::AddOfSelf, patterns::AddOfPad>(
context);
421OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
424 [=](
const APSInt &a,
const APSInt &b) { return a - b; });
427void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
429 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
430 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
431 patterns::SubOfPadL, patterns::SubOfPadR>(
context);
434OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
446 [=](
const APSInt &a,
const APSInt &b) { return a * b; });
449OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
456 if (getLhs() == getRhs()) {
457 auto width = getType().base().getWidthOrSentinel();
462 return getIntAttr(getType(), APInt(width, 1));
479 if (
auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
480 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
485 [=](
const APSInt &a,
const APSInt &b) -> APInt {
488 return APInt(a.getBitWidth(), 0);
492OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
499 if (getLhs() == getRhs())
513 [=](
const APSInt &a,
const APSInt &b) -> APInt {
516 return APInt(a.getBitWidth(), 0);
520OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
523 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
526OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
529 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
532OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
535 [=](
const APSInt &a,
const APSInt &b) -> APInt {
536 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
542OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
545 if (rhsCst->isZero())
549 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
556 if (lhsCst->isZero())
560 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
561 getRhs().getType() == getType())
566 if (getLhs() == getRhs() && getRhs().getType() == getType())
571 [](
const APSInt &a,
const APSInt &b) -> APInt { return a & b; });
574void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
577 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
578 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
579 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(
context);
582OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
585 if (rhsCst->isZero() && getLhs().getType() == getType())
589 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
590 getLhs().getType() == getType())
596 if (lhsCst->isZero() && getRhs().getType() == getType())
600 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
601 getRhs().getType() == getType())
606 if (getLhs() == getRhs() && getRhs().getType() == getType())
611 [](
const APSInt &a,
const APSInt &b) -> APInt { return a | b; });
614void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
616 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
617 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
621OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
624 if (rhsCst->isZero() &&
630 if (lhsCst->isZero() &&
635 if (getLhs() == getRhs())
638 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
642 [](
const APSInt &a,
const APSInt &b) -> APInt { return a ^ b; });
645void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
647 results.insert<patterns::extendXor, patterns::moveConstXor,
648 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
652void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
654 results.insert<patterns::LEQWithConstLHS>(
context);
657OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
658 bool isUnsigned = getLhs().getType().base().isUnsigned();
661 if (getLhs() == getRhs())
665 if (
auto width = getLhs().getType().base().
getWidth()) {
667 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
668 commonWidth = std::max(commonWidth, 1);
679 if (isUnsigned && rhsCst->zext(commonWidth)
692 [=](
const APSInt &a,
const APSInt &b) -> APInt {
693 return APInt(1, a <= b);
697void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
699 results.insert<patterns::LTWithConstLHS>(
context);
702OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
703 IntType lhsType = getLhs().getType();
707 if (getLhs() == getRhs())
717 if (
auto width = lhsType.
getWidth()) {
719 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
720 commonWidth = std::max(commonWidth, 1);
731 if (isUnsigned && rhsCst->zext(commonWidth)
744 [=](
const APSInt &a,
const APSInt &b) -> APInt {
745 return APInt(1, a < b);
749void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
751 results.insert<patterns::GEQWithConstLHS>(
context);
754OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
755 IntType lhsType = getLhs().getType();
759 if (getLhs() == getRhs())
764 if (rhsCst->isZero() && isUnsigned)
769 if (
auto width = lhsType.
getWidth()) {
771 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
772 commonWidth = std::max(commonWidth, 1);
775 if (isUnsigned && rhsCst->zext(commonWidth)
796 [=](
const APSInt &a,
const APSInt &b) -> APInt {
797 return APInt(1, a >= b);
801void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
803 results.insert<patterns::GTWithConstLHS>(
context);
806OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
807 IntType lhsType = getLhs().getType();
811 if (getLhs() == getRhs())
815 if (
auto width = lhsType.
getWidth()) {
817 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
818 commonWidth = std::max(commonWidth, 1);
821 if (isUnsigned && rhsCst->zext(commonWidth)
842 [=](
const APSInt &a,
const APSInt &b) -> APInt {
843 return APInt(1, a > b);
847OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
849 if (getLhs() == getRhs())
855 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
856 getRhs().getType() == getType())
862 [=](
const APSInt &a,
const APSInt &b) -> APInt {
863 return APInt(1, a == b);
867LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
869 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
871 auto width = op.getLhs().getType().getBitWidthOrSentinel();
874 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
875 op.getRhs().getType() == op.getType()) {
876 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
881 if (rhsCst->isZero() && width > 1) {
882 auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
883 return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
887 if (rhsCst->isAllOnes() && width > 1 &&
888 op.getLhs().getType() == op.getRhs().getType()) {
889 return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
897OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
899 if (getLhs() == getRhs())
905 if (rhsCst->isZero() && getLhs().getType() == getType() &&
906 getRhs().getType() == getType())
912 [=](
const APSInt &a,
const APSInt &b) -> APInt {
913 return APInt(1, a != b);
917LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
919 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
921 auto width = op.getLhs().getType().getBitWidthOrSentinel();
924 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
925 op.getRhs().getType() == op.getType()) {
926 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
931 if (rhsCst->isZero() && width > 1) {
932 return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
937 if (rhsCst->isAllOnes() && width > 1 &&
938 op.getLhs().getType() == op.getRhs().getType()) {
940 AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
941 return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
949OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
955OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
961OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
965 return IntegerAttr::get(IntegerType::get(getContext(),
966 lhsCst->getBitWidth(),
967 IntegerType::Signed),
968 lhsCst->ashr(*rhsCst));
971 if (rhsCst->isZero())
978OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
983 return IntegerAttr::get(IntegerType::get(getContext(),
984 lhsCst->getBitWidth(),
985 IntegerType::Signed),
986 lhsCst->shl(*rhsCst));
989 if (rhsCst->isZero())
1000OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
1001 auto base = getInput().getType();
1008OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
1015OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1023 if (getType().base().hasWidth())
1030void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1032 results.insert<patterns::StoUtoS>(
context);
1035OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1043 if (getType().base().hasWidth())
1050void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1052 results.insert<patterns::UtoStoU>(
context);
1055OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1057 if (getInput().getType() == getType())
1062 return BoolAttr::get(getContext(), cst->getBoolValue());
1067OpFoldResult AsResetPrimOp::fold(FoldAdaptor adaptor) {
1069 return BoolAttr::get(getContext(), cst->getBoolValue());
1073OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1075 if (getInput().getType() == getType())
1080 return BoolAttr::get(getContext(), cst->getBoolValue());
1085OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1091 getType().base().getWidthOrSentinel()))
1097void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1099 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(
context);
1102OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1109 getType().base().getWidthOrSentinel()))
1110 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1115OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1120 getType().base().getWidthOrSentinel()))
1126void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1128 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1129 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1136 : RewritePattern(opName, 0,
context) {}
1142 ConstantOp constantOp,
1143 SmallVectorImpl<Value> &remaining)
const = 0;
1150 mlir::PatternRewriter &rewriter)
const override {
1152 auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
1156 SmallVector<Value> nonConstantOperands;
1159 for (
auto operand : catOp.getInputs()) {
1160 if (
auto constantOp = operand.getDefiningOp<ConstantOp>()) {
1162 if (
handleConstant(rewriter, op, constantOp, nonConstantOperands))
1166 nonConstantOperands.push_back(operand);
1171 if (nonConstantOperands.empty()) {
1172 replaceOpWithNewOpAndCopyName<ConstantOp>(
1173 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1179 if (nonConstantOperands.size() == 1) {
1180 rewriter.modifyOpInPlace(
1181 op, [&] { op->setOperand(0, nonConstantOperands.front()); });
1186 if (catOp->hasOneUse() &&
1187 nonConstantOperands.size() < catOp->getNumOperands()) {
1188 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
1189 nonConstantOperands);
1202 SmallVectorImpl<Value> &remaining)
const override {
1203 if (value.getValue().isZero())
1206 replaceOpWithNewOpAndCopyName<ConstantOp>(
1207 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1220 SmallVectorImpl<Value> &remaining)
const override {
1221 if (value.getValue().isAllOnes())
1224 replaceOpWithNewOpAndCopyName<ConstantOp>(
1225 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1238 SmallVectorImpl<Value> &remaining)
const override {
1239 if (value.getValue().isZero())
1241 remaining.push_back(value);
1247OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1251 if (getInput().getType().getBitWidthOrSentinel() == 0)
1256 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1260 if (
isUInt1(getInput().getType()))
1266void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1268 results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1269 patterns::AndRPadS, patterns::AndRCatAndR_left,
1273OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1277 if (getInput().getType().getBitWidthOrSentinel() == 0)
1282 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1286 if (
isUInt1(getInput().getType()))
1292void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1294 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1295 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right,
OrRCat>(
1299OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1303 if (getInput().getType().getBitWidthOrSentinel() == 0)
1308 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1311 if (
isUInt1(getInput().getType()))
1317void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1320 .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1321 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right,
XorRCat>(
1329OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1330 auto inputs = getInputs();
1331 auto inputAdaptors = adaptor.getInputs();
1338 if (inputs.size() == 1 && inputs[0].getType() == getType())
1346 SmallVector<Value> nonZeroInputs;
1347 SmallVector<Attribute> nonZeroAttributes;
1348 bool allConstant =
true;
1349 for (
auto [input, attr] :
llvm::zip(inputs, inputAdaptors)) {
1350 auto inputType = type_cast<IntType>(input.getType());
1351 if (inputType.getBitWidthOrSentinel() != 0) {
1352 nonZeroInputs.push_back(input);
1354 allConstant =
false;
1355 if (nonZeroInputs.size() > 1 && !allConstant)
1361 if (nonZeroInputs.empty())
1365 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1366 return nonZeroInputs[0];
1372 SmallVector<APInt> constants;
1373 for (
auto inputAdaptor : inputAdaptors) {
1375 constants.push_back(*cst);
1380 assert(!constants.empty());
1382 APInt result = constants[0];
1383 for (
size_t i = 1; i < constants.size(); ++i)
1384 result = result.concat(constants[i]);
1389void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1391 results.insert<patterns::DShlOfConstant>(
context);
1394void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1396 results.insert<patterns::DShrOfConstant>(
context);
1402class FlattenCat :
public mlir::RewritePattern {
1404 FlattenCat(MLIRContext *
context)
1405 : RewritePattern(CatPrimOp::getOperationName(), 0,
context) {}
1408 matchAndRewrite(Operation *op,
1409 mlir::PatternRewriter &rewriter)
const override {
1410 auto cat = cast<CatPrimOp>(op);
1412 cat.getType().getBitWidthOrSentinel() == 0)
1416 if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
1420 SmallVector<Value> operands;
1421 SmallVector<Value> worklist;
1422 auto pushOperands = [&worklist](CatPrimOp op) {
1423 for (
auto operand :
llvm::reverse(op.getInputs()))
1424 worklist.push_back(operand);
1427 bool hasSigned =
false, hasUnsigned =
false;
1428 while (!worklist.empty()) {
1429 auto value = worklist.pop_back_val();
1430 auto catOp = value.getDefiningOp<CatPrimOp>();
1432 operands.push_back(value);
1433 (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) =
true;
1437 pushOperands(catOp);
1442 auto castToUIntIfSigned = [&](Value value) -> Value {
1443 if (type_isa<UIntType>(value.getType()))
1445 return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
1448 assert(operands.size() >= 1 &&
"zero width cast must be rejected");
1450 if (operands.size() == 1) {
1451 rewriter.replaceOp(op, castToUIntIfSigned(operands[0]));
1455 if (operands.size() == cat->getNumOperands())
1459 if (hasSigned && hasUnsigned)
1460 for (
auto &operand : operands)
1461 operand = castToUIntIfSigned(operand);
1463 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
1470class CatOfConstant :
public mlir::RewritePattern {
1472 CatOfConstant(MLIRContext *
context)
1473 : RewritePattern(CatPrimOp::getOperationName(), 0,
context) {}
1476 matchAndRewrite(Operation *op,
1477 mlir::PatternRewriter &rewriter)
const override {
1478 auto cat = cast<CatPrimOp>(op);
1482 SmallVector<Value> operands;
1484 for (
size_t i = 0; i < cat->getNumOperands(); ++i) {
1485 auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
1487 operands.push_back(cat.getInputs()[i]);
1490 APSInt value = cst.getValue();
1492 for (; j < cat->getNumOperands(); ++j) {
1493 auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
1496 value = value.concat(nextCst.getValue());
1501 operands.push_back(cst);
1504 operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
1510 if (operands.size() == cat->getNumOperands())
1513 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
1522void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1524 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1525 patterns::CatCast, FlattenCat, CatOfConstant>(
context);
1528OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1531 if (op.getType() == op.getInput().getType())
1532 return op.getInput();
1536 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1537 if (op.getType() == in.getInput().getType())
1538 return in.getInput();
1543OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1544 IntType inputType = getInput().getType();
1545 IntType resultType = getType();
1547 if (inputType == getType() && resultType.
hasWidth())
1554 cst->extractBits(getHi() - getLo() + 1, getLo()));
1561 : RewritePattern(BitsPrimOp::getOperationName(), 0,
context) {}
1565 mlir::PatternRewriter &rewriter)
const override {
1566 auto bits = cast<BitsPrimOp>(op);
1567 auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
1570 int32_t bitPos = bits.getLo();
1571 auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
1572 if (resultWidth < 0)
1574 for (
auto operand : llvm::reverse(cat.getInputs())) {
1576 type_cast<IntType>(operand.getType()).getWidthOrSentinel();
1577 if (operandWidth < 0)
1579 if (bitPos < operandWidth) {
1580 if (bitPos + resultWidth <= operandWidth) {
1581 auto newBits = rewriter.createOrFold<BitsPrimOp>(
1582 op->getLoc(), operand, bitPos + resultWidth - 1, bitPos);
1588 bitPos -= operandWidth;
1594void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1597 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1605 unsigned loBit, PatternRewriter &rewriter) {
1606 auto resType = type_cast<IntType>(op->getResult(0).getType());
1607 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1608 value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);
1610 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1611 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1612 }
else if (resType.isUnsigned() &&
1613 !type_cast<IntType>(value.getType()).isUnsigned()) {
1614 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1616 rewriter.replaceOp(op, value);
1619template <
typename OpTy>
1620static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
1622 if (op.getType().getBitWidthOrSentinel() == 0)
1624 APInt(0, 0, op.getType().isSignedInteger()));
1627 if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
1628 return op.getHigh();
1633 if (op.getType().getBitWidthOrSentinel() < 0)
1638 if (cond->isZero() && op.getLow().getType() == op.getType())
1640 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1641 return op.getHigh();
1645 if (
auto lowCst =
getConstant(adaptor.getLow())) {
1647 if (
auto highCst =
getConstant(adaptor.getHigh())) {
1649 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1650 *highCst == *lowCst)
1653 if (highCst->isOne() && lowCst->isZero() &&
1654 op.getType() == op.getSel().getType())
1667OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1668 return foldMux(*
this, adaptor);
1671OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1672 return foldMux(*
this, adaptor);
// Fold hook for Mux4CellIntrinsicOp. It unconditionally returns an empty
// OpFoldResult, which reports that no fold was performed. Unlike
// MuxPrimOp::fold and Mux2CellIntrinsicOp::fold, it does not delegate to
// foldMux.
// NOTE(review): presumably this is because foldMux is written against
// getSel()/getHigh()/getLow() accessors, which a 4-way mux may not provide;
// confirm against the op definition.
1675OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return {}; }
1682class MuxPad :
public mlir::RewritePattern {
1685 : RewritePattern(MuxPrimOp::getOperationName(), 0,
context) {}
1688 matchAndRewrite(Operation *op,
1689 mlir::PatternRewriter &rewriter)
const override {
1690 auto mux = cast<MuxPrimOp>(op);
1691 auto width = mux.getType().getBitWidthOrSentinel();
1695 auto pad = [&](Value input) -> Value {
1697 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1698 if (inputWidth < 0 || width == inputWidth)
1700 return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
1705 auto newHigh = pad(mux.getHigh());
1706 auto newLow = pad(mux.getLow());
1707 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1710 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1711 rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1719class MuxSharedCond :
public mlir::RewritePattern {
1721 MuxSharedCond(MLIRContext *
context)
1722 : RewritePattern(MuxPrimOp::getOperationName(), 0,
context) {}
1724 static const int depthLimit = 5;
1726 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1727 mlir::PatternRewriter &rewriter,
1728 bool updateInPlace)
const {
1729 if (updateInPlace) {
1730 rewriter.modifyOpInPlace(mux, [&] {
1731 mux.setOperand(1, high);
1732 mux.setOperand(2, low);
1736 rewriter.setInsertionPointAfter(mux);
1737 return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
1738 ValueRange{mux.getSel(), high, low})
1743 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1744 bool updateInPlace,
int limit)
const {
1745 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1748 if (mux.getSel() == cond)
1749 return mux.getHigh();
1750 if (limit > depthLimit)
1752 updateInPlace &= mux->hasOneUse();
1754 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1756 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1759 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1760 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1765 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1766 bool updateInPlace,
int limit)
const {
1767 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1770 if (mux.getSel() == cond)
1771 return mux.getLow();
1772 if (limit > depthLimit)
1774 updateInPlace &= mux->hasOneUse();
1776 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1778 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1780 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1782 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1788 matchAndRewrite(Operation *op,
1789 mlir::PatternRewriter &rewriter)
const override {
1790 auto mux = cast<MuxPrimOp>(op);
1791 auto width = mux.getType().getBitWidthOrSentinel();
1795 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter,
true, 0)) {
1796 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1800 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter,
true, 0)) {
1801 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
1810void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1813 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1814 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1815 patterns::MuxSameTrue, patterns::MuxSameFalse,
1816 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1820void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1821 RewritePatternSet &results, MLIRContext *
context) {
1822 results.add<patterns::Mux2PadSel>(
context);
1825void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1826 RewritePatternSet &results, MLIRContext *
context) {
1827 results.add<patterns::Mux4PadSel>(
context);
1830OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1831 auto input = this->getInput();
1834 if (input.getType() == getType())
1838 auto inputType = input.getType().base();
1845 auto destWidth = getType().base().getWidthOrSentinel();
1846 if (destWidth == -1)
1849 if (inputType.
isSigned() && cst->getBitWidth())
1850 return getIntAttr(getType(), cst->sext(destWidth));
1851 return getIntAttr(getType(), cst->zext(destWidth));
1857OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1858 auto input = this->getInput();
1859 IntType inputType = input.getType();
1860 int shiftAmount = getAmount();
1863 if (shiftAmount == 0)
1869 if (inputWidth != -1) {
1870 auto resultWidth = inputWidth + shiftAmount;
1871 shiftAmount = std::min(shiftAmount, resultWidth);
1872 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
1878OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
1879 auto input = this->getInput();
1880 IntType inputType = input.getType();
1881 int shiftAmount = getAmount();
1887 if (shiftAmount == 0 && inputWidth > 0)
1890 if (inputWidth == -1)
1892 if (inputWidth == 0)
1897 if (shiftAmount >= inputWidth && inputType.
isUnsigned())
1898 return getIntAttr(getType(), APInt(0, 0,
false));
1904 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
1906 value = cst->lshr(std::min(shiftAmount, inputWidth));
1907 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
1908 return getIntAttr(getType(), value.trunc(resultWidth));
1913LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
1914 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1915 if (inputWidth <= 0)
1919 unsigned shiftAmount = op.getAmount();
1920 if (
int(shiftAmount) >= inputWidth) {
1922 if (op.getType().base().isUnsigned())
1928 shiftAmount = inputWidth - 1;
1931 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
1935LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1936 PatternRewriter &rewriter) {
1937 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1938 if (inputWidth <= 0)
1942 unsigned keepAmount = op.getAmount();
1944 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
1949OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1953 getInput().getType().base().getWidthOrSentinel() - getAmount();
1954 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
1960OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1964 cst->trunc(getType().base().getWidthOrSentinel()));
1968LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1969 PatternRewriter &rewriter) {
1970 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1971 if (inputWidth <= 0)
1975 unsigned dropAmount = op.getAmount();
1976 if (dropAmount !=
unsigned(inputWidth))
1982void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1984 results.add<patterns::SubaccessOfConstant>(
context);
1987OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
1989 if (adaptor.getInputs().size() == 1)
1990 return getOperand(1);
1992 if (
auto constIndex =
getConstant(adaptor.getIndex())) {
1993 auto index = constIndex->getZExtValue();
1994 if (index < getInputs().size())
1995 return getInputs()[getInputs().size() - 1 - index];
2001LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
2002 PatternRewriter &rewriter) {
2006 if (llvm::all_of(op.getInputs().drop_front(), [&](
auto input) {
2007 return input == op.getInputs().front();
2015 auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
2016 uint64_t inputSize = op.getInputs().size();
2017 if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
2018 rewriter.modifyOpInPlace(op, [&]() {
2019 op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
2026 if (
auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
2027 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](
auto e) {
2028 auto subindex = e.value().template getDefiningOp<SubindexOp>();
2029 return subindex && lastSubindex.getInput() == subindex.getInput() &&
2030 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
2032 replaceOpWithNewOpAndCopyName<SubaccessOp>(
2033 rewriter, op, lastSubindex.getInput(), op.getIndex());
2039 if (op.getInputs().size() != 2)
2043 auto uintType = op.getIndex().getType();
2044 if (uintType.getBitWidthOrSentinel() != 1)
2048 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
2049 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
2068 MatchingConnectOp connect;
2069 for (Operation *user : value.getUsers()) {
2071 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
2074 if (
auto aConnect = dyn_cast<FConnectLike>(user))
2075 if (aConnect.getDest() == value) {
2076 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
2079 if (!matchingConnect || (connect && connect != matchingConnect) ||
2080 matchingConnect->getBlock() != value.getParentBlock())
2082 connect = matchingConnect;
2090 PatternRewriter &rewriter) {
2093 Operation *connectedDecl = op.getDest().getDefiningOp();
2098 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
2102 cast<Forceable>(connectedDecl).isForceable())
2110 if (connectedDecl->hasOneUse())
2114 auto *declBlock = connectedDecl->getBlock();
2115 auto *srcValueOp = op.getSrc().getDefiningOp();
2118 if (!isa<WireOp>(connectedDecl))
2124 if (!isa<ConstantOp>(srcValueOp))
2126 if (srcValueOp->getBlock() != declBlock)
2132 auto replacement = op.getSrc();
2135 if (srcValueOp && srcValueOp != &declBlock->front())
2136 srcValueOp->moveBefore(&declBlock->front());
2143 rewriter.eraseOp(op);
2147void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2149 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
2153LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2154 PatternRewriter &rewriter) {
2171 for (
auto *user : value.getUsers()) {
2172 auto attach = dyn_cast<AttachOp>(user);
2173 if (!attach || attach == dominatedAttach)
2175 if (attach->isBeforeInBlock(dominatedAttach))
2181LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
2183 if (op.getNumOperands() <= 1) {
2184 rewriter.eraseOp(op);
2188 for (
auto operand : op.getOperands()) {
2195 SmallVector<Value> newOperands(op.getOperands());
2196 for (
auto newOperand : attach.getOperands())
2197 if (newOperand != operand)
2198 newOperands.push_back(newOperand);
2199 AttachOp::create(rewriter, op->getLoc(), newOperands);
2200 rewriter.eraseOp(attach);
2201 rewriter.eraseOp(op);
2209 if (
auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
2210 if (!
hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
2211 !wire.isForceable()) {
2212 SmallVector<Value> newOperands;
2213 for (
auto newOperand : op.getOperands())
2214 if (newOperand != operand)
2215 newOperands.push_back(newOperand);
2217 AttachOp::create(rewriter, op->getLoc(), newOperands);
2218 rewriter.eraseOp(op);
2219 rewriter.eraseOp(wire);
2230 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
2231 rewriter.inlineBlockBefore(®ion.front(), op, {});
2234LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
2235 if (
auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
2236 if (constant.getValue().isAllOnes())
2238 else if (op.hasElseRegion() && !op.getElseRegion().empty())
2241 rewriter.eraseOp(op);
2247 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
2248 op.getElseBlock().empty()) {
2249 rewriter.eraseBlock(&op.getElseBlock());
2256 if (!op.getThenBlock().empty())
2260 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
2261 rewriter.eraseOp(op);
2270struct FoldNodeName :
public mlir::RewritePattern {
2271 FoldNodeName(MLIRContext *
context)
2272 : RewritePattern(NodeOp::getOperationName(), 0,
context) {}
2273 LogicalResult matchAndRewrite(Operation *op,
2274 PatternRewriter &rewriter)
const override {
2275 auto node = cast<NodeOp>(op);
2276 auto name = node.getNameAttr();
2277 if (!node.hasDroppableName() || node.getInnerSym() ||
2280 auto *newOp = node.getInput().getDefiningOp();
2283 rewriter.replaceOp(node, node.getInput());
2289struct NodeBypass :
public mlir::RewritePattern {
2290 NodeBypass(MLIRContext *
context)
2291 : RewritePattern(NodeOp::getOperationName(), 0,
context) {}
2292 LogicalResult matchAndRewrite(Operation *op,
2293 PatternRewriter &rewriter)
const override {
2294 auto node = cast<NodeOp>(op);
2296 node.use_empty() || node.isForceable())
2298 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
2305template <
typename OpTy>
2307 PatternRewriter &rewriter) {
2308 if (!op.isForceable() || !op.getDataRef().use_empty())
2316LogicalResult NodeOp::fold(FoldAdaptor adaptor,
2317 SmallVectorImpl<OpFoldResult> &results) {
2326 if (!adaptor.getInput())
2329 results.push_back(adaptor.getInput());
2333void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2335 results.insert<FoldNodeName>(
context);
2336 results.add(demoteForceableIfUnused<NodeOp>);
2342struct AggOneShot :
public mlir::RewritePattern {
2343 AggOneShot(StringRef name, uint32_t weight, MLIRContext *
context)
2344 : RewritePattern(name, 0,
context) {}
2346 SmallVector<Value> getCompleteWrite(Operation *lhs)
const {
2347 auto lhsTy = lhs->getResult(0).getType();
2348 if (!type_isa<BundleType, FVectorType>(lhsTy))
2351 DenseMap<uint32_t, Value> fields;
2352 for (Operation *user : lhs->getResult(0).getUsers()) {
2353 if (user->getParentOp() != lhs->getParentOp())
2355 if (
auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2356 if (aConnect.getDest() == lhs->getResult(0))
2358 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2359 for (Operation *subuser : subField.getResult().getUsers()) {
2360 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2361 if (aConnect.getDest() == subField) {
2362 if (subuser->getParentOp() != lhs->getParentOp())
2364 if (fields.count(subField.getFieldIndex()))
2366 fields[subField.getFieldIndex()] = aConnect.getSrc();
2372 }
else if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2373 for (Operation *subuser : subIndex.getResult().getUsers()) {
2374 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2375 if (aConnect.getDest() == subIndex) {
2376 if (subuser->getParentOp() != lhs->getParentOp())
2378 if (fields.count(subIndex.getIndex()))
2380 fields[subIndex.getIndex()] = aConnect.getSrc();
2391 SmallVector<Value> values;
2392 uint32_t total = type_isa<BundleType>(lhsTy)
2393 ? type_cast<BundleType>(lhsTy).getNumElements()
2394 : type_cast<FVectorType>(lhsTy).getNumElements();
2395 for (uint32_t i = 0; i < total; ++i) {
2396 if (!fields.count(i))
2398 values.push_back(fields[i]);
2403 LogicalResult matchAndRewrite(Operation *op,
2404 PatternRewriter &rewriter)
const override {
2405 auto values = getCompleteWrite(op);
2408 rewriter.setInsertionPointToEnd(op->getBlock());
2409 auto dest = op->getResult(0);
2410 auto destType = dest.getType();
2413 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2416 Value newVal = type_isa<BundleType>(destType)
2417 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2419 : rewriter.createOrFold<VectorCreateOp>(
2420 op->
getLoc(), destType, values);
2421 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2422 for (Operation *user : dest.getUsers()) {
2423 if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2424 for (Operation *subuser :
2425 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2426 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2427 if (aConnect.getDest() == subIndex)
2428 rewriter.eraseOp(aConnect);
2429 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2430 for (Operation *subuser :
2431 llvm::make_early_inc_range(subField.getResult().getUsers()))
2432 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2433 if (aConnect.getDest() == subField)
2434 rewriter.eraseOp(aConnect);
2441struct WireAggOneShot :
public AggOneShot {
2442 WireAggOneShot(MLIRContext *
context)
2443 : AggOneShot(WireOp::getOperationName(), 0,
context) {}
2445struct SubindexAggOneShot :
public AggOneShot {
2446 SubindexAggOneShot(MLIRContext *
context)
2447 : AggOneShot(SubindexOp::getOperationName(), 0,
context) {}
2449struct SubfieldAggOneShot :
public AggOneShot {
2450 SubfieldAggOneShot(MLIRContext *
context)
2451 : AggOneShot(SubfieldOp::getOperationName(), 0,
context) {}
2455void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2457 results.insert<WireAggOneShot>(
context);
2458 results.add(demoteForceableIfUnused<WireOp>);
2461void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2463 results.insert<SubindexAggOneShot>(
context);
2466OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2467 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2470 return attr[getIndex()];
2473OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2474 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2477 auto index = getFieldIndex();
2481void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2483 results.insert<SubfieldAggOneShot>(
context);
2487 ArrayRef<Attribute> operands) {
2488 for (
auto operand : operands)
2491 return ArrayAttr::get(
context, operands);
2494OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2497 if (getNumOperands() > 0)
2498 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2499 if (first.getFieldIndex() == 0 &&
2500 first.getInput().getType() == getType() &&
2502 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2504 elem.value().
template getDefiningOp<SubfieldOp>();
2505 return subindex && subindex.getInput() == first.getInput() &&
2506 subindex.getFieldIndex() == elem.index();
2508 return first.getInput();
2513OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2516 if (getNumOperands() > 0)
2517 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2518 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2520 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2522 elem.value().
template getDefiningOp<SubindexOp>();
2523 return subindex && subindex.getInput() == first.getInput() &&
2524 subindex.getIndex() == elem.index();
2526 return first.getInput();
2531OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2532 if (getOperand().getType() == getType())
2533 return getOperand();
2540struct FoldResetMux :
public mlir::RewritePattern {
2541 FoldResetMux(MLIRContext *
context)
2542 : RewritePattern(RegResetOp::getOperationName(), 0,
context) {}
2543 LogicalResult matchAndRewrite(Operation *op,
2544 PatternRewriter &rewriter)
const override {
2545 auto reg = cast<RegResetOp>(op);
2547 dyn_cast_or_null<ConstantOp>(
reg.getResetValue().getDefiningOp());
2556 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2559 auto *high = mux.getHigh().getDefiningOp();
2560 auto *low = mux.getLow().getDefiningOp();
2561 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2563 if (constOp && low != reg)
2565 if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
2566 constOp = dyn_cast<ConstantOp>(low);
2568 if (!constOp || constOp.getType() != reset.getType() ||
2569 constOp.getValue() != reset.getValue())
2573 auto regTy =
reg.getResult().getType();
2574 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2575 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2576 regTy.getBitWidthOrSentinel() < 0)
2582 if (constOp != &con->getBlock()->front())
2583 constOp->moveBefore(&con->getBlock()->front());
2588 rewriter.eraseOp(con);
2595 if (
auto c = v.getDefiningOp<ConstantOp>())
2596 return c.getValue().isOne();
2597 if (
auto sc = v.getDefiningOp<SpecialConstantOp>())
2598 return sc.getValue();
2607 auto resetValue = reg.getResetValue();
2608 if (reg.getType(0) != resetValue.getType())
2612 (void)
dropWrite(rewriter, reg->getResult(0), {});
2613 replaceOpWithNewOpAndCopyName<NodeOp>(
2614 rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
2615 reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
2619void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2621 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(
context);
2623 results.add(demoteForceableIfUnused<RegResetOp>);
2628 auto portTy = type_cast<BundleType>(port.getType());
2629 auto fieldIndex = portTy.getElementIndex(name);
2630 assert(fieldIndex &&
"missing field on memory port");
2633 for (
auto *op : port.getUsers()) {
2634 auto portAccess = cast<SubfieldOp>(op);
2635 if (fieldIndex != portAccess.getFieldIndex())
2640 value = conn.getSrc();
2650 auto portConst = value.getDefiningOp<ConstantOp>();
2653 return portConst.getValue().isZero();
2658 auto portTy = type_cast<BundleType>(port.getType());
2659 auto fieldIndex = portTy.getElementIndex(
data);
2660 assert(fieldIndex &&
"missing enable flag on memory port");
2662 for (
auto *op : port.getUsers()) {
2663 auto portAccess = cast<SubfieldOp>(op);
2664 if (fieldIndex != portAccess.getFieldIndex())
2666 if (!portAccess.use_empty())
2675 StringRef name, Value value) {
2676 auto portTy = type_cast<BundleType>(port.getType());
2677 auto fieldIndex = portTy.getElementIndex(name);
2678 assert(fieldIndex &&
"missing field on memory port");
2680 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2681 auto portAccess = cast<SubfieldOp>(op);
2682 if (fieldIndex != portAccess.getFieldIndex())
2684 rewriter.replaceAllUsesWith(portAccess, value);
2685 rewriter.eraseOp(portAccess);
2690static void erasePort(PatternRewriter &rewriter, Value port) {
2693 auto getClock = [&] {
2695 clock = SpecialConstantOp::create(rewriter, port.getLoc(),
2696 ClockType::get(rewriter.getContext()),
2705 for (
auto *op : port.getUsers()) {
2706 auto subfield = dyn_cast<SubfieldOp>(op);
2708 auto ty = port.getType();
2709 auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
2710 rewriter.replaceAllUsesWith(port, reg.getResult());
2719 for (
auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2720 auto access = cast<SubfieldOp>(accessOp);
2721 for (
auto *user : llvm::make_early_inc_range(access->getUsers())) {
2722 auto connect = dyn_cast<FConnectLike>(user);
2723 if (connect && connect.getDest() == access) {
2724 rewriter.eraseOp(user);
2728 if (access.use_empty()) {
2729 rewriter.eraseOp(access);
2735 auto ty = access.getType();
2736 auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
2737 rewriter.replaceOp(access, reg.getResult());
2739 assert(port.use_empty() &&
"port should have no remaining uses");
2744struct FoldZeroWidthMemory :
public mlir::RewritePattern {
2745 FoldZeroWidthMemory(MLIRContext *
context)
2746 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
2747 LogicalResult matchAndRewrite(Operation *op,
2748 PatternRewriter &rewriter)
const override {
2749 MemOp mem = cast<MemOp>(op);
2753 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2754 mem.getDataType().getBitWidthOrSentinel() != 0)
2758 for (
auto port : mem.getResults())
2759 for (auto *user : port.getUsers())
2760 if (!isa<SubfieldOp>(user))
2765 for (
auto port : op->getResults()) {
2766 for (
auto *user :
llvm::make_early_inc_range(port.getUsers())) {
2767 SubfieldOp sfop = cast<SubfieldOp>(user);
2768 StringRef fieldName = sfop.getFieldName();
2769 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2770 rewriter, sfop, sfop.getResult().getType())
2772 if (fieldName.ends_with(
"data")) {
2774 auto zero = firrtl::ConstantOp::create(
2775 rewriter, wire.getLoc(),
2776 firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
2777 MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
2781 rewriter.eraseOp(op);
2787struct FoldReadOrWriteOnlyMemory :
public mlir::RewritePattern {
2788 FoldReadOrWriteOnlyMemory(MLIRContext *
context)
2789 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
2790 LogicalResult matchAndRewrite(Operation *op,
2791 PatternRewriter &rewriter)
const override {
2792 MemOp mem = cast<MemOp>(op);
2795 bool isRead =
false, isWritten =
false;
2796 for (
unsigned i = 0; i < mem.getNumResults(); ++i) {
2797 switch (mem.getPortKind(i)) {
2798 case MemOp::PortKind::Read:
2803 case MemOp::PortKind::Write:
2808 case MemOp::PortKind::Debug:
2809 case MemOp::PortKind::ReadWrite:
2812 llvm_unreachable(
"unknown port kind");
2814 assert((!isWritten || !isRead) &&
"memory is in use");
2819 if (isRead && mem.getInit())
2822 for (
auto port : mem.getResults())
2825 rewriter.eraseOp(op);
2831struct FoldUnusedPorts :
public mlir::RewritePattern {
2832 FoldUnusedPorts(MLIRContext *
context)
2833 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
2834 LogicalResult matchAndRewrite(Operation *op,
2835 PatternRewriter &rewriter)
const override {
2836 MemOp mem = cast<MemOp>(op);
2840 llvm::SmallBitVector deadPorts(mem.getNumResults());
2841 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2843 if (!mem.getPortAnnotation(i).empty())
2847 auto kind = mem.getPortKind(i);
2848 if (kind == MemOp::PortKind::Debug)
2857 if (kind == MemOp::PortKind::Read &&
isPortUnused(port,
"data")) {
2862 if (deadPorts.none())
2866 SmallVector<Type> resultTypes;
2867 SmallVector<StringRef> portNames;
2868 SmallVector<Attribute> portAnnotations;
2869 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2872 resultTypes.push_back(port.getType());
2873 portNames.push_back(mem.getPortName(i));
2874 portAnnotations.push_back(mem.getPortAnnotation(i));
2878 if (!resultTypes.empty())
2879 newOp = MemOp::create(
2880 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
2881 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2882 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2883 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2884 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2887 unsigned nextPort = 0;
2888 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2892 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
2895 rewriter.eraseOp(op);
2901struct FoldReadWritePorts :
public mlir::RewritePattern {
2902 FoldReadWritePorts(MLIRContext *
context)
2903 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
2904 LogicalResult matchAndRewrite(Operation *op,
2905 PatternRewriter &rewriter)
const override {
2906 MemOp mem = cast<MemOp>(op);
2911 llvm::SmallBitVector deadReads(mem.getNumResults());
2912 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2913 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
2915 if (!mem.getPortAnnotation(i).empty())
2922 if (deadReads.none())
2925 SmallVector<Type> resultTypes;
2926 SmallVector<StringRef> portNames;
2927 SmallVector<Attribute> portAnnotations;
2928 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2930 resultTypes.push_back(
2931 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
2932 MemOp::PortKind::Write, mem.getMaskBits()));
2934 resultTypes.push_back(port.getType());
2936 portNames.push_back(mem.getPortName(i));
2937 portAnnotations.push_back(mem.getPortAnnotation(i));
2940 auto newOp = MemOp::create(
2941 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
2942 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2943 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2944 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2945 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2947 for (
unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
2948 auto result = mem.getResult(i);
2949 auto newResult = newOp.getResult(i);
2951 auto resultPortTy = type_cast<BundleType>(result.getType());
2955 auto replace = [&](StringRef toName, StringRef fromName) {
2956 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
2957 assert(fromFieldIndex &&
"missing enable flag on memory port");
2959 auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
2961 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
2962 auto fromField = cast<SubfieldOp>(op);
2963 if (fromFieldIndex != fromField.getFieldIndex())
2965 rewriter.replaceOp(fromField, toField.getResult());
2969 replace(
"addr",
"addr");
2970 replace(
"en",
"en");
2971 replace(
"clk",
"clk");
2972 replace(
"data",
"wdata");
2973 replace(
"mask",
"wmask");
2976 auto wmodeFieldIndex = resultPortTy.getElementIndex(
"wmode");
2977 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
2978 auto wmodeField = cast<SubfieldOp>(op);
2979 if (wmodeFieldIndex != wmodeField.getFieldIndex())
2981 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
2984 rewriter.replaceAllUsesWith(result, newResult);
2987 rewriter.eraseOp(op);
2993struct FoldUnusedBits :
public mlir::RewritePattern {
2994 FoldUnusedBits(MLIRContext *
context)
2995 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
2997 LogicalResult matchAndRewrite(Operation *op,
2998 PatternRewriter &rewriter)
const override {
2999 MemOp mem = cast<MemOp>(op);
3004 const auto &summary = mem.getSummary();
3005 if (summary.isMasked || summary.isSeqMem())
3008 auto type = type_dyn_cast<IntType>(mem.getDataType());
3011 auto width = type.getBitWidthOrSentinel();
3015 llvm::SmallBitVector usedBits(width);
3016 DenseMap<unsigned, unsigned> mapping;
3021 SmallVector<BitsPrimOp> readOps;
3022 auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
3023 auto portTy = type_cast<BundleType>(port.getType());
3024 auto fieldIndex = portTy.getElementIndex(field);
3025 assert(fieldIndex &&
"missing data port");
3027 for (
auto *op : port.getUsers()) {
3028 auto portAccess = cast<SubfieldOp>(op);
3029 if (fieldIndex != portAccess.getFieldIndex())
3032 for (
auto *user : op->getUsers()) {
3033 auto bits = dyn_cast<BitsPrimOp>(user);
3037 usedBits.set(bits.getLo(), bits.getHi() + 1);
3041 mapping[bits.getLo()] = 0;
3042 readOps.push_back(bits);
3052 SmallVector<MatchingConnectOp> writeOps;
3053 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
3054 auto portTy = type_cast<BundleType>(port.getType());
3055 auto fieldIndex = portTy.getElementIndex(field);
3056 assert(fieldIndex &&
"missing data port");
3058 for (
auto *op : port.getUsers()) {
3059 auto portAccess = cast<SubfieldOp>(op);
3060 if (fieldIndex != portAccess.getFieldIndex())
3067 writeOps.push_back(conn);
3073 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3075 if (!mem.getPortAnnotation(i).empty())
3078 switch (mem.getPortKind(i)) {
3079 case MemOp::PortKind::Debug:
3082 case MemOp::PortKind::Write:
3083 if (failed(findWriteUsers(port,
"data")))
3086 case MemOp::PortKind::Read:
3087 if (failed(findReadUsers(port,
"data")))
3090 case MemOp::PortKind::ReadWrite:
3091 if (failed(findWriteUsers(port,
"wdata")))
3093 if (failed(findReadUsers(port,
"rdata")))
3097 llvm_unreachable(
"unknown port kind");
3101 if (usedBits.none())
3105 SmallVector<std::pair<unsigned, unsigned>> ranges;
3106 unsigned newWidth = 0;
3107 for (
int i = usedBits.find_first(); 0 <= i && i < width;) {
3108 int e = usedBits.find_next_unset(i);
3111 for (
int idx = i; idx < e; ++idx, ++newWidth) {
3112 if (
auto it = mapping.find(idx); it != mapping.end()) {
3113 it->second = newWidth;
3116 ranges.emplace_back(i, e - 1);
3117 i = e != width ? usedBits.find_next(e) : e;
3121 auto newType =
IntType::get(op->getContext(), type.isSigned(), newWidth);
3122 SmallVector<Type> portTypes;
3123 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3124 portTypes.push_back(
3125 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
3127 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
3128 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
3129 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
3130 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
3131 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3134 auto rewriteSubfield = [&](Value port, StringRef field) {
3135 auto portTy = type_cast<BundleType>(port.getType());
3136 auto fieldIndex = portTy.getElementIndex(field);
3137 assert(fieldIndex &&
"missing data port");
3139 rewriter.setInsertionPointAfter(newMem);
3140 auto newPortAccess =
3141 SubfieldOp::create(rewriter, port.getLoc(), port, field);
3143 for (
auto *op :
llvm::make_early_inc_range(port.getUsers())) {
3144 auto portAccess = cast<SubfieldOp>(op);
3145 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
3147 rewriter.replaceOp(portAccess, newPortAccess.getResult());
3152 for (
auto [i, port] :
llvm::enumerate(newMem.getResults())) {
3153 switch (newMem.getPortKind(i)) {
3154 case MemOp::PortKind::Debug:
3155 llvm_unreachable(
"cannot rewrite debug port");
3156 case MemOp::PortKind::Write:
3157 rewriteSubfield(port,
"data");
3159 case MemOp::PortKind::Read:
3160 rewriteSubfield(port,
"data");
3162 case MemOp::PortKind::ReadWrite:
3163 rewriteSubfield(port,
"rdata");
3164 rewriteSubfield(port,
"wdata");
3167 llvm_unreachable(
"unknown port kind");
3171 for (
auto readOp : readOps) {
3172 rewriter.setInsertionPointAfter(readOp);
3173 auto it = mapping.find(readOp.getLo());
3174 assert(it != mapping.end() &&
"bit op mapping not found");
3177 auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
3178 readOp.getLoc(), readOp.getInput(),
3179 readOp.getHi() - readOp.getLo() + it->second, it->second);
3180 rewriter.replaceAllUsesWith(readOp, newReadValue);
3181 rewriter.eraseOp(readOp);
3185 for (
auto writeOp : writeOps) {
3186 Value source = writeOp.getSrc();
3187 rewriter.setInsertionPoint(writeOp);
3189 SmallVector<Value> slices;
3190 for (
auto &[start, end] :
llvm::reverse(ranges)) {
3191 Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
3192 source,
end, start);
3193 slices.push_back(slice);
3197 rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);
3203 if (type.isSigned())
3205 rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);
3207 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
3216struct FoldRegMems :
public mlir::RewritePattern {
3217 FoldRegMems(MLIRContext *
context)
3218 : RewritePattern(MemOp::getOperationName(), 0,
context) {}
3219 LogicalResult matchAndRewrite(Operation *op,
3220 PatternRewriter &rewriter)
const override {
3221 MemOp mem = cast<MemOp>(op);
3226 auto ty = mem.getDataType();
3227 auto loc = mem.getLoc();
3228 auto *block = mem->getBlock();
3232 SmallPtrSet<Operation *, 8> connects;
3233 SmallVector<SubfieldOp> portAccesses;
3234 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3235 if (!mem.getPortAnnotation(i).empty())
3238 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
3239 auto portTy = type_cast<BundleType>(port.getType());
3240 for (
auto field : fields) {
3241 auto fieldIndex = portTy.getElementIndex(field);
3242 assert(fieldIndex &&
"missing field on memory port");
3244 for (
auto *op : port.getUsers()) {
3245 auto portAccess = cast<SubfieldOp>(op);
3246 if (fieldIndex != portAccess.getFieldIndex())
3248 portAccesses.push_back(portAccess);
3249 for (
auto *user : portAccess->getUsers()) {
3250 auto conn = dyn_cast<FConnectLike>(user);
3253 connects.insert(conn);
3260 switch (mem.getPortKind(i)) {
3261 case MemOp::PortKind::Debug:
3263 case MemOp::PortKind::Read:
3264 if (failed(collect({
"clk",
"en",
"addr"})))
3267 case MemOp::PortKind::Write:
3268 if (failed(collect({
"clk",
"en",
"addr",
"data",
"mask"})))
3271 case MemOp::PortKind::ReadWrite:
3272 if (failed(collect({
"clk",
"en",
"addr",
"wmode",
"wdata",
"wmask"})))
3278 if (!portClock || (clock && portClock != clock))
3284 rewriter.setInsertionPointAfter(mem);
3285 auto memWire = WireOp::create(rewriter, loc, ty).getResult();
3291 rewriter.setInsertionPointToEnd(block);
3293 RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();
3296 MatchingConnectOp::create(rewriter, loc, memWire, memReg);
3300 auto pipeline = [&](Value value, Value clock,
const Twine &name,
3302 for (
unsigned i = 0; i < latency; ++i) {
3303 std::string regName;
3305 llvm::raw_string_ostream os(regName);
3306 os << mem.getName() <<
"_" << name <<
"_" << i;
3308 auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
3309 rewriter.getStringAttr(regName))
3311 MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
3317 const unsigned writeStages =
info.writeLatency - 1;
3322 SmallVector<std::tuple<Value, Value, Value>> writes;
3323 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3325 StringRef name = mem.getPortName(i);
3327 auto portPipeline = [&, port = port](StringRef field,
unsigned stages) {
3330 return pipeline(value, portClock, name +
"_" + field, stages);
3333 switch (mem.getPortKind(i)) {
3334 case MemOp::PortKind::Debug:
3335 llvm_unreachable(
"unknown port kind");
3336 case MemOp::PortKind::Read: {
3344 case MemOp::PortKind::Write: {
3345 auto data = portPipeline(
"data", writeStages);
3346 auto en = portPipeline(
"en", writeStages);
3347 auto mask = portPipeline(
"mask", writeStages);
3351 case MemOp::PortKind::ReadWrite: {
3356 auto wdata = portPipeline(
"wdata", writeStages);
3357 auto wmask = portPipeline(
"wmask", writeStages);
3362 auto wen = AndPrimOp::create(rewriter, port.getLoc(),
en,
wmode);
3364 pipeline(wen, portClock, name +
"_wen", writeStages);
3365 writes.emplace_back(
wdata, wenPipelined,
wmask);
3372 Value next = memReg;
3378 Location loc = mem.getLoc();
3379 unsigned maskGran =
info.dataWidth /
info.maskBits;
3380 SmallVector<Value> chunks;
3381 for (
unsigned i = 0; i <
info.maskBits; ++i) {
3382 unsigned hi = (i + 1) * maskGran - 1;
3383 unsigned lo = i * maskGran;
3385 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc,
data, hi, lo);
3386 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3387 auto bit = rewriter.createOrFold<BitsPrimOp>(loc,
mask, i, i);
3388 auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
3389 chunks.push_back(chunk);
3392 std::reverse(chunks.begin(), chunks.end());
3393 masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
3394 next = MuxPrimOp::create(rewriter, next.getLoc(),
en, masked, next);
3396 Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
3397 MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);
3400 for (Operation *conn : connects)
3401 rewriter.eraseOp(conn);
3402 for (
auto portAccess : portAccesses)
3403 rewriter.eraseOp(portAccess);
3404 rewriter.eraseOp(mem);
3411void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3414 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3415 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3435 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3438 auto *high = mux.getHigh().getDefiningOp();
3439 auto *low = mux.getLow().getDefiningOp();
3441 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3448 bool constReg =
false;
3450 if (constOp && low == reg)
3452 else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
3454 constOp = dyn_cast<ConstantOp>(low);
3461 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3465 auto regTy = reg.getResult().getType();
3466 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3467 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3468 regTy.getBitWidthOrSentinel() < 0)
3474 if (constOp != &con->getBlock()->front())
3475 constOp->moveBefore(&con->getBlock()->front());
3478 SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
3479 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3480 rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
3481 mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
3482 reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
3483 reg.getForceableAttr());
3484 newReg->setDialectAttrs(attrs);
3486 auto pt = rewriter.saveInsertionPoint();
3487 rewriter.setInsertionPoint(con);
3488 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3489 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3490 rewriter.restoreInsertionPoint(pt);
3494LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3495 if (!
hasDontTouch(op.getOperation()) && !op.isForceable() &&
3511 PatternRewriter &rewriter,
3514 if (
auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3515 if (constant.getValue().isZero()) {
3516 rewriter.eraseOp(op);
3522 if (
auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3523 if (constant.getValue().isZero() == eraseIfZero) {
3524 rewriter.eraseOp(op);
3532template <
class Op,
bool EraseIfZero = false>
3534 PatternRewriter &rewriter) {
3539void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3541 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3542 results.add<patterns::AssertXWhenX>(
context);
3545void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3547 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3548 results.add<patterns::AssumeXWhenX>(
context);
3551void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3552 RewritePatternSet &results, MLIRContext *
context) {
3553 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3554 results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(
context);
3557void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3559 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3566LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3567 PatternRewriter &rewriter) {
3569 if (op.use_empty()) {
3570 rewriter.eraseOp(op);
3577 if (op->hasOneUse() &&
3578 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3579 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3580 *op->user_begin()) ||
3581 (isa<CvtPrimOp>(*op->user_begin()) &&
3582 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3583 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3584 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3585 .getBitWidthOrSentinel() > 0))) {
3586 auto *modop = *op->user_begin();
3587 auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
3588 modop->getResult(0).getType());
3589 rewriter.replaceAllOpUsesWith(modop, inv);
3590 rewriter.eraseOp(modop);
3591 rewriter.eraseOp(op);
3597OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3598 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3599 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3607OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3616 return BoolAttr::get(getContext(),
false);
3620 return BoolAttr::get(getContext(),
false);
3625LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3626 PatternRewriter &rewriter) {
3628 if (
auto testEnable = op.getTestEnable()) {
3629 if (
auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3630 if (constOp.getValue().isZero()) {
3631 rewriter.modifyOpInPlace(op,
3632 [&] { op.getTestEnableMutable().clear(); });
3648 auto forceable = op.getRef().getDefiningOp<Forceable>();
3649 if (!forceable || !forceable.isForceable() ||
3650 op.getRef() != forceable.getDataRef() ||
3651 op.getType() != forceable.getDataType())
3653 rewriter.replaceAllUsesWith(op, forceable.getData());
3657void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3659 results.insert<patterns::RefResolveOfRefSend>(
context);
3663OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3665 if (getInput().getType() == getType())
3671 auto constOp = operand.getDefiningOp<ConstantOp>();
3672 return constOp && constOp.getValue().isZero();
3675template <
typename Op>
3678 rewriter.eraseOp(op);
3684void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3686 results.add(eraseIfPredFalse<RefForceOp>);
3688void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3690 results.add(eraseIfPredFalse<RefForceInitialOp>);
3692void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3694 results.add(eraseIfPredFalse<RefReleaseOp>);
3696void RefReleaseInitialOp::getCanonicalizationPatterns(
3697 RewritePatternSet &results, MLIRContext *
context) {
3698 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3705OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3711 if (adaptor.getReset())
3716 if (
isUInt1(getReset().getType()) && adaptor.getClock())
3729 [&](
auto ty) ->
bool {
return isTypeEmpty(ty.getElementType()); })
3730 .Case<BundleType>([&](
auto ty) ->
bool {
3731 for (
auto elem : ty.getElements())
3736 .Case<IntType>([&](
auto ty) {
return ty.getWidth() == 0; })
3737 .Default([](
auto) ->
bool {
return false; });
3740LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3741 PatternRewriter &rewriter) {
3742 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3749 rewriter.eraseOp(op);
3757LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3758 PatternRewriter &rewriter) {
3761 if (op.getBody()->empty()) {
3762 rewriter.eraseOp(op);
3773OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3775 if (getDomains().
empty())
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
BitsOfCat(MLIRContext *context)