15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/SCF/IR/SCF.h"
19#include "mlir/IR/AffineExpr.h"
20#include "mlir/IR/AffineMap.h"
21#include "mlir/IR/OperationSupport.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Support/LLVM.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
31#define GEN_PASS_DEF_MEMORYBANKING
32#include "circt/Transforms/Passes.h.inc"
// Memory-banking pass. The members below carry state from bank creation in
// runOnOperation to the rewrite patterns that runOnOperation constructs.
42struct MemoryBankingPass
43 :
public circt::impl::MemoryBankingBase<MemoryBankingPass> {
44 MemoryBankingPass(
const MemoryBankingPass &other) =
default;
// NOTE(review): the body of this constructor is empty, so `bankingFactor` and
// `bankingDimension` passed here (e.g. from createMemoryBankingPass) are
// ignored. runOnOperation reads `bankingFactor` / `bankingDimension`, which are
// presumably the pass options declared in MemoryBankingBase. Consider
// forwarding non-empty arguments to those options; confirm against the
// generated base class.
45 explicit MemoryBankingPass(
46 std::optional<unsigned> bankingFactor = std::nullopt,
47 std::optional<int> bankingDimension = std::nullopt) {}
49 void runOnOperation()
override;
// Original memref value -> the bank memrefs created for it (filled by
// createBanks in runOnOperation).
53 DenseMap<Value, SmallVector<Value>> memoryToBanks;
// Operations scheduled for erasure; handed to the store rewrite pattern.
54 DenseSet<Operation *> opsToErase;
// Original memref values to be removed once they are no longer used.
57 DenseSet<Value> oldMemRefVals;
// Gather every memref-typed operand of the operations visited by parOp.walk
// (parOp and everything nested in it). A DenseSet is used, so each value is
// recorded once.
63 DenseSet<Value> memrefVals;
64 parOp.walk([&](Operation *op) {
65 for (
auto operand : op->getOperands()) {
66 if (isa<MemRefType>(operand.getType()))
67 memrefVals.insert(operand);
69 return WalkResult::advance();
76 unsigned bankingFactor,
77 MemRefType originalType) {
// Debug-only sanity checks on a banking configuration. They are plain
// asserts, so nothing is checked in NDEBUG builds:
//  - the memref is not 0-d,
//  - bankingDimension is a valid dimension index,
//  - the banked extent is an exact multiple of bankingFactor.
78 ArrayRef<int64_t> originalShape = originalType.getShape();
79 assert(!originalShape.empty() &&
"memref shape should not be empty");
80 assert(bankingDimension < originalType.getRank() &&
81 "dimension must be within the memref rank");
82 assert(originalShape[bankingDimension] % bankingFactor == 0 &&
83 "memref shape must be evenly divided by the banking factor");
87 uint64_t bankingFactor,
88 unsigned bankingDimension) {
// Type of a single bank: identical to originalType (element type, layout,
// memory space) except that the extent of bankingDimension is divided by
// bankingFactor. This is integer division, so the result is exact only when
// the extent is evenly divisible.
89 ArrayRef<int64_t> originalShape = originalType.getShape();
90 SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
91 newShape[bankingDimension] /= bankingFactor;
92 MemRefType newMemRefType =
93 MemRefType::get(newShape, originalType.getElementType(),
94 originalType.getLayout(), originalType.getMemorySpace());
// Convert a row-major linear index into an N-d index for `shape`. The last
// dimension varies fastest: each step takes `% shape[d]`, then divides.
// NOTE(review): `rank - 1` is evaluated in unsigned arithmetic. For an empty
// shape it wraps to UINT_MAX, which is positive as int64_t, so the loop would
// index out of bounds. Consider `static_cast<int64_t>(rank) - 1`.
102SmallVector<int64_t>
decodeIndex(int64_t linIndex, ArrayRef<int64_t> shape) {
103 const unsigned rank = shape.size();
104 SmallVector<int64_t> ndIndex(rank, 0);
107 for (int64_t d = rank - 1; d >= 0; --d) {
108 ndIndex[d] = linIndex % shape[d];
109 linIndex /= shape[d];
// Distribute the elements of a constant (given in row-major linear order as
// allAttrs) across bankingFactor sub-blocks. The bank of each element is its
// index along bankingDimension modulo bankingFactor (cyclic banking). Within a
// bank, elements keep their relative linear order.
119SmallVector<SmallVector<Attribute>>
sliceSubBlock(ArrayRef<Attribute> allAttrs,
120 ArrayRef<int64_t> memShape,
121 unsigned bankingDimension,
122 unsigned bankingFactor) {
// NOTE(review): the init value `1` is an int, so std::reduce accumulates in
// int and each size_t product is truncated back to int. Use size_t{1} if very
// large constants are possible. The loop counter below is also `unsigned`,
// not size_t.
123 size_t numElements = std::reduce(memShape.begin(), memShape.end(), 1,
124 std::multiplies<size_t>());
127 SmallVector<SmallVector<Attribute>> subBlocks;
128 subBlocks.resize(bankingFactor);
130 for (
unsigned linIndex = 0; linIndex <
numElements; ++linIndex) {
131 SmallVector<int64_t> ndIndex =
decodeIndex(linIndex, memShape);
132 unsigned subBlockIndex = ndIndex[bankingDimension] % bankingFactor;
133 subBlocks[subBlockIndex].push_back(allAttrs[linIndex]);
143 uint64_t bankingFactor,
144 unsigned bankingDimension,
145 MemRefType newMemRefType,
146 OpBuilder &builder) {
// Bank a memref.get_global:
//  1. look up the referenced memref.global,
//  2. split its constant initial value into bankingFactor sub-blocks
//     (sliceSubBlock),
//  3. emit one new memref.global "<name>_bank_<i>" per sub-block right after
//     the original global,
//  4. emit one memref.get_global per bank right after the original
//     get_global.
// The new get_global results are collected in `banks`.
147 SmallVector<Value, 4> banks;
148 auto memTy = cast<MemRefType>(getGlobalOp.getType());
149 ArrayRef<int64_t> originalShape = memTy.getShape();
151 SmallVector<int64_t>(originalShape.begin(), originalShape.end());
152 newShape[bankingDimension] = originalShape[bankingDimension] / bankingFactor;
154 auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
155 auto globalOpNameAttr = getGlobalOp.getNameAttr();
156 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
157 SymbolTable::lookupSymbolIn(symbolTableOp, globalOpNameAttr));
158 assert(globalOp &&
"The corresponding GlobalOp should exist in the module");
159 MemRefType globalOpTy = globalOp.getType();
// NOTE(review): cstAttr is produced by dyn_cast_or_null and is dereferenced
// below without a null check. A global with no DenseElementsAttr constant
// initializer would crash here. Consider diagnosing that case.
162 dyn_cast_or_null<DenseElementsAttr>(globalOp.getConstantInitValue());
163 auto attributes = cstAttr.getValues<Attribute>();
164 SmallVector<Attribute, 8> allAttrs(attributes.begin(), attributes.end());
167 sliceSubBlock(allAttrs, originalShape, bankingDimension, bankingFactor);
// Two independent insertion cursors: one advancing after the original global
// (for the new globals), one after the original get_global (for the new
// get_globals). Each keeps the banks in order 0..bankingFactor-1.
173 builder.setInsertionPointAfter(globalOp);
174 OpBuilder::InsertPoint globalOpsInsertPt = builder.saveInsertionPoint();
175 builder.setInsertionPointAfter(getGlobalOp);
176 OpBuilder::InsertPoint getGlobalOpsInsertPt = builder.saveInsertionPoint();
178 for (
size_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
// NOTE(review): this per-bank type uses the default layout and memory space.
// The `newMemRefType` parameter appears unused in the visible lines; confirm
// whether the two types are meant to agree.
180 auto newMemRefTy = MemRefType::get(newShape, globalOpTy.getElementType());
181 auto newTypeAttr = TypeAttr::get(newMemRefTy);
182 std::string newName = llvm::formatv(
183 "{0}_{1}_{2}", globalOpNameAttr.getValue(),
"bank", bankCnt);
184 RankedTensorType tensorType =
185 RankedTensorType::get({newShape}, globalOpTy.getElementType());
186 auto newInitValue = DenseElementsAttr::get(tensorType, subBlocks[bankCnt]);
188 builder.restoreInsertionPoint(globalOpsInsertPt);
189 auto newGlobalOp = builder.create<memref::GlobalOp>(
190 globalOp.getLoc(), builder.getStringAttr(newName),
191 globalOp.getSymVisibilityAttr(), newTypeAttr, newInitValue,
192 globalOp.getConstantAttr(), globalOp.getAlignmentAttr());
193 builder.setInsertionPointAfter(newGlobalOp);
194 globalOpsInsertPt = builder.saveInsertionPoint();
196 builder.restoreInsertionPoint(getGlobalOpsInsertPt);
197 auto newGetGlobalOp = builder.create<memref::GetGlobalOp>(
198 getGlobalOp.getLoc(), newMemRefTy, newGlobalOp.getName());
199 builder.setInsertionPointAfter(newGetGlobalOp);
200 getGlobalOpsInsertPt = builder.saveInsertionPoint();
202 banks.push_back(newGetGlobalOp);
211 ArrayRef<int64_t> shape) {
// Choose the banking dimension:
//  - an explicitly provided, non-negative dimension is returned as-is (its
//    range is not checked here);
//  - otherwise, dimensions are scanned from the last one down, looking for an
//    extent > 1.
// Whether the scan stops at the first match depends on lines not visible here
// (presumably a `break`); confirm.
217 if (bankingDimensionOpt.has_value() && *bankingDimensionOpt >= 0) {
218 return static_cast<unsigned>(*bankingDimensionOpt);
224 int bankingDimension = -1;
225 for (
int dim = rank - 1; dim >= 0; --dim) {
226 if (shape[dim] > 1) {
227 bankingDimension = dim;
// Debug-only: every dimension has extent <= 1, so there is nothing to bank.
232 assert(bankingDimension >= 0 &&
"No eligible dimension for banking");
233 return static_cast<unsigned>(bankingDimension);
239 unsigned &bankingDimension) {
// Override bankingFactor / bankingDimension in place from IntegerAttrs named
// "banking.factor" / "banking.dimension". The attributes are read from one of:
//  - the defining op, when originalMem is an op result;
//  - the func.func argument attributes, when originalMem is a block argument.
// Absent attributes leave the incoming values unchanged.
// NOTE(review): getInt() yields int64_t, which is implicitly narrowed to
// unsigned here, so a negative attribute value would wrap.
240 if (
auto *originalDef = originalMem.getDefiningOp()) {
241 if (
auto attrFactor = dyn_cast_if_present<IntegerAttr>(
242 originalDef->getAttr(
"banking.factor")))
243 bankingFactor = attrFactor.getInt();
244 if (
auto attrDimension = dyn_cast_if_present<IntegerAttr>(
245 originalDef->getAttr(
"banking.dimension")))
246 bankingDimension = attrDimension.getInt();
// Block-argument case: the owning block's parent op is expected to be a
// func::FuncOp.
251 if (isa<BlockArgument>(originalMem)) {
252 auto blockArg = cast<BlockArgument>(originalMem);
253 auto *parentOp = blockArg.getOwner()->getParentOp();
255 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
257 "Expected the original memory to be a FuncOp block argument!");
259 unsigned argIndex = blockArg.getArgNumber();
260 if (
auto argAttrs = funcOp.getArgAttrDict(argIndex)) {
261 if (
auto attrFactor =
262 dyn_cast_if_present<IntegerAttr>(argAttrs.get(
"banking.factor")))
263 bankingFactor = attrFactor.getInt();
264 if (
auto attrDimension = dyn_cast_if_present<IntegerAttr>(
265 argAttrs.get(
"banking.dimension")))
266 bankingDimension = attrDimension.getInt();
276 MemRefType newMemRefType,
277 unsigned numInsertedArgs) {
// Rebuild funcOp's function type so it includes numInsertedArgs extra inputs
// of type newMemRefType. Result types are unchanged. Exactly where the new
// inputs go (presumably right after argIndex) is decided by lines not visible
// here; confirm.
278 auto originalArgTypes = funcOp.getFunctionType().getInputs();
279 SmallVector<Type, 4> updatedArgTypes;
283 for (
unsigned i = 0; i < originalArgTypes.size(); ++i) {
284 updatedArgTypes.push_back(originalArgTypes[i]);
288 for (
unsigned j = 0; j < numInsertedArgs; ++j) {
289 updatedArgTypes.push_back(newMemRefType);
295 auto resultTypes = funcOp.getFunctionType().getResults();
297 FunctionType::get(funcOp.getContext(), updatedArgTypes, resultTypes);
298 funcOp.setType(newFuncType);
304 unsigned numInsertedArgs) {
// Rewrite funcOp's "arg_attrs" for numInsertedArgs new arguments placed
// directly after argIndex:
//  - attributes of arguments after argIndex shift up by numInsertedArgs;
//  - the new slots argIndex+1 .. argIndex+numInsertedArgs get empty
//    dictionaries;
//  - with no existing "arg_attrs", every slot is an empty dictionary.
// NOTE(review): numArguments is read from funcOp.getNumArguments(). This is
// only correct if the function's argument count does not yet include the
// inserted arguments when this is called. Confirm the call order relative to
// updateFuncOpArgumentTypes.
305 ArrayAttr existingArgAttrs = funcOp->getAttrOfType<ArrayAttr>(
"arg_attrs");
306 SmallVector<Attribute, 4> updatedArgAttrs;
307 unsigned numArguments = funcOp.getNumArguments();
308 unsigned newNumArguments = numArguments + numInsertedArgs;
309 updatedArgAttrs.resize(newNumArguments);
312 for (
unsigned i = 0; i < numArguments; ++i) {
314 unsigned newIndex = (i > argIndex) ? i + numInsertedArgs : i;
315 updatedArgAttrs[newIndex] = existingArgAttrs
316 ? existingArgAttrs[i]
317 : DictionaryAttr::get(funcOp.getContext());
321 for (
unsigned i = 0; i < numInsertedArgs; ++i) {
322 updatedArgAttrs[argIndex + 1 + i] =
323 DictionaryAttr::get(funcOp.getContext());
327 funcOp->setAttr(
"arg_attrs",
328 ArrayAttr::get(funcOp.getContext(), updatedArgAttrs));
// Create bankingFactor bank memrefs of the banked type for originalMem and
// return them in bank order:
//  - block argument: insert bankingFactor new block arguments directly after
//    the original one. The parent op must be a func::FuncOp (asserted).
//  - memref.alloc / memref.alloca: create bankingFactor new allocations right
//    after the original definition.
//  - memref.get_global: split the referenced global (delegated; the call is
//    partly outside this view).
//  - any other defining op is unreachable.
331SmallVector<Value, 4>
createBanks(Value originalMem,
unsigned bankingFactor,
332 std::optional<int> bankingDimensionOpt) {
333 MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
334 unsigned rank = originalMemRefType.getRank();
335 ArrayRef<int64_t> shape = originalMemRefType.getShape();
337 unsigned bankingDimension =
346 originalMemRefType, bankingFactor, bankingDimension);
347 SmallVector<Value, 4> banks;
348 if (
auto blockArgMem = dyn_cast<BlockArgument>(originalMem)) {
349 Block *block = blockArgMem.getOwner();
350 unsigned blockArgNum = blockArgMem.getArgNumber();
352 for (
unsigned i = 0; i < bankingFactor; ++i)
353 block->insertArgument(blockArgNum + 1 + i, newMemRefType,
354 blockArgMem.getLoc());
357 block->getArguments().slice(blockArgNum + 1, bankingFactor);
358 banks.append(blockArgs.begin(), blockArgs.end());
360 auto *parentOp = block->getParentOp();
361 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
362 assert(funcOp &&
"BlockArgument is not part of a FuncOp");
// Op-result case: new banks are placed immediately after the original
// definition.
370 Operation *originalDef = originalMem.getDefiningOp();
371 Location loc = originalDef->getLoc();
372 OpBuilder builder(originalDef);
373 builder.setInsertionPointAfter(originalDef);
374 TypeSwitch<Operation *>(originalDef)
// NOTE(review): the alloc/alloca banks are built without dynamic size
// operands, which presumes a statically shaped memref; confirm.
375 .Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
376 for (uint64_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
378 builder.create<memref::AllocOp>(loc, newMemRefType);
379 banks.push_back(bankAllocOp);
382 .Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
383 for (uint64_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
385 builder.create<memref::AllocaOp>(loc, newMemRefType);
386 banks.push_back(bankAllocaOp);
389 .Case<memref::GetGlobalOp>([&](memref::GetGlobalOp getGlobalOp) {
392 newMemRefType, builder);
393 banks.append(newBanks.begin(), newBanks.end());
395 .Default([](Operation *) {
396 llvm_unreachable(
"Unhandled memory operation type");
414 PatternRewriter &rewriter)
const override {
// Replace an affine.load from a banked memref with an scf.index_switch over
// the bank index (cyclic banking along bankingDimension):
//   bank   = idx[bankingDimension] mod      bankingFactor   (modMap)
//   offset = idx[bankingDimension] floordiv bankingFactor   (divMap)
// Case i loads from banks[i] at the original indices, with the banked
// dimension replaced by `offset`.
415 Location loc = loadOp.getLoc();
416 auto originalMem = loadOp.getMemref();
418 auto loadIndices = loadOp.getIndices();
419 MemRefType originalMemRefType = loadOp.getMemRefType();
420 int64_t memrefRank = originalMemRefType.getRank();
421 ArrayRef<int64_t> shape = originalMemRefType.getShape();
423 auto bankingDimension =
431 auto modMap = AffineMap::get(
433 {rewriter.getAffineDimExpr(bankingDimension) %
bankingFactor});
434 auto divMap = AffineMap::get(
436 {rewriter.getAffineDimExpr(bankingDimension).floorDiv(
bankingFactor)});
439 rewriter.create<affine::AffineApplyOp>(loc, modMap, loadIndices);
441 rewriter.create<affine::AffineApplyOp>(loc, divMap, loadIndices);
442 SmallVector<Value, 4> newIndices(loadIndices.begin(), loadIndices.end());
443 newIndices[bankingDimension] = offset;
445 SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
447 SmallVector<int64_t, 4> caseValues;
449 caseValues.push_back(i);
451 rewriter.setInsertionPoint(loadOp);
452 scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
453 loc, resultTypes, bankIndex, caseValues,
// One case region per bank: load from banks[i] and yield the loaded value.
457 Region &caseRegion = switchOp.getCaseRegions()[i];
458 rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
459 Value bankedLoad = rewriter.create<mlir::affine::AffineLoadOp>(
460 loc, banks[i], newIndices);
461 rewriter.create<scf::YieldOp>(loc, bankedLoad);
// Default region. The affine `mod` result lies in [0, bankingFactor), so every
// bank index matches a case and this region is not taken. It yields a zero
// constant only to satisfy the switch's result type.
// NOTE(review): getZeroAttr does not produce a TypedAttr for every type, and
// cast<> asserts if it does not. Confirm that only int/index/float element
// types reach here.
464 Region &defaultRegion = switchOp.getDefaultRegion();
465 assert(defaultRegion.empty() &&
"Default region should be empty");
466 rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
469 cast<TypedAttr>(rewriter.getZeroAttr(loadOp.getType()));
470 auto defaultValue = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
471 rewriter.create<scf::YieldOp>(loc, defaultValue.getResult());
// Special handling when the original memref is a block argument. The guarded
// statement is not visible here (presumably it records the value for later
// cleanup).
475 if (Value memRef = loadOp.getMemref(); isa<BlockArgument>(memRef))
477 rewriter.replaceOp(loadOp, switchOp.getResult(0));
504 PatternRewriter &rewriter)
const override {
// Replace an affine.store into a banked memref with a result-less
// scf.index_switch over the bank index (cyclic banking along
// bankingDimension):
//   bank   = idx[bankingDimension] mod      bankingFactor   (modMap)
//   offset = idx[bankingDimension] floordiv bankingFactor   (divMap)
// Case i stores the value into banks[i] at the rewritten indices.
508 Location loc = storeOp.getLoc();
509 auto originalMem = storeOp.getMemref();
511 auto storeIndices = storeOp.getIndices();
512 auto originalMemRefType = storeOp.getMemRefType();
513 int64_t memrefRank = originalMemRefType.getRank();
514 ArrayRef<int64_t> shape = originalMemRefType.getShape();
516 auto bankingDimension =
524 auto modMap = AffineMap::get(
526 {rewriter.getAffineDimExpr(bankingDimension) %
bankingFactor});
527 auto divMap = AffineMap::get(
529 {rewriter.getAffineDimExpr(bankingDimension).floorDiv(
bankingFactor)});
532 rewriter.create<affine::AffineApplyOp>(loc, modMap, storeIndices);
534 rewriter.create<affine::AffineApplyOp>(loc, divMap, storeIndices);
535 SmallVector<Value, 4> newIndices(storeIndices.begin(), storeIndices.end());
536 newIndices[bankingDimension] = offset;
// Stores produce no value, so the switch has no results.
538 SmallVector<Type> resultTypes = {};
540 SmallVector<int64_t, 4> caseValues;
542 caseValues.push_back(i);
544 rewriter.setInsertionPoint(storeOp);
545 scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
546 loc, resultTypes, bankIndex, caseValues,
// One case region per bank: store into banks[i], then an empty yield.
550 Region &caseRegion = switchOp.getCaseRegions()[i];
551 rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
552 rewriter.create<mlir::affine::AffineStoreOp>(
553 loc, storeOp.getValueToStore(), banks[i], newIndices);
554 rewriter.create<scf::YieldOp>(loc);
// Default region: performs no store; it only terminates the region.
557 Region &defaultRegion = switchOp.getDefaultRegion();
558 assert(defaultRegion.empty() &&
"Default region should be empty");
559 rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
561 rewriter.create<scf::YieldOp>(loc);
587 PatternRewriter &rewriter)
const override {
// Rewrite a func.return so that banked memrefs are returned as their banks.
// The operands are appended to newReturnOperands (the filtering lines are not
// visible here). A new func.return is built at the end of the function's first
// block, and the function's result types are updated to match.
588 Location loc = returnOp.getLoc();
589 SmallVector<Value, 4> newReturnOperands;
590 bool allOrigMemsUsedByReturn =
true;
591 for (
auto operand : returnOp.getOperands()) {
593 newReturnOperands.push_back(operand);
596 if (operand.hasOneUse())
597 allOrigMemsUsedByReturn =
false;
599 newReturnOperands.append(banks.begin(), banks.end());
// NOTE(review): the new return is inserted at the end of the *first* block.
// For multi-block functions, returnOp may live in a different block; confirm
// that only single-block functions reach here.
602 func::FuncOp funcOp = returnOp.getParentOp();
603 rewriter.setInsertionPointToEnd(&funcOp.getBlocks().front());
605 rewriter.create<func::ReturnOp>(loc, ValueRange(newReturnOperands));
606 TypeRange newReturnType = TypeRange(newReturnOperands);
// NOTE(review): funcOp.setType mutates an op in place without going through
// the rewriter (e.g. rewriter.modifyOpInPlace). The greedy driver is not
// notified of that change.
607 FunctionType newFuncType =
608 FunctionType::get(funcOp.getContext(),
609 funcOp.getFunctionType().getInputs(), newReturnType);
610 funcOp.setType(newFuncType);
612 if (allOrigMemsUsedByReturn)
613 rewriter.replaceOp(returnOp, newReturnOp);
627 DenseSet<Operation *> &opsToErase) {
// Remove the original (pre-banking) memrefs once they are unused:
//  1. record every old value, plus (function, argument index) for func.func
//     block arguments;
//  2. process opsToErase (body not visible here);
//  3. erase each old value: block arguments by current position, op results
//     via their defining op (partially visible);
//  4. drop the erased arguments' entries from "arg_attrs" and rebuild each
//     touched function's input types from its remaining block arguments.
628 DenseSet<func::FuncOp> funcsToModify;
629 SmallVector<Value, 4> valuesToErase;
630 DenseMap<func::FuncOp, SmallVector<unsigned, 4>> erasedArgIndices;
// Indices are recorded before any erasure, i.e. in the pre-erasure numbering
// that "arg_attrs" also uses.
631 for (
auto &memrefVal : oldMemRefVals) {
632 valuesToErase.push_back(memrefVal);
633 if (
auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
635 dyn_cast<func::FuncOp>(blockArg.getOwner()->getParentOp())) {
636 funcsToModify.insert(funcOp);
637 erasedArgIndices[funcOp].push_back(blockArg.getArgNumber());
642 for (
auto *op : opsToErase) {
// getArgNumber() is re-read at erase time, so earlier erasures that shift the
// argument positions are accounted for.
646 for (
auto &memrefVal : valuesToErase) {
647 assert(memrefVal.use_empty() &&
"use must be empty");
648 if (
auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
649 blockArg.getOwner()->eraseArgument(blockArg.getArgNumber());
650 }
else if (
auto *op = memrefVal.getDefiningOp()) {
656 for (
auto funcOp : funcsToModify) {
657 ArrayAttr existingArgAttrs = funcOp->getAttrOfType<ArrayAttr>(
"arg_attrs");
658 if (existingArgAttrs) {
659 SmallVector<Attribute, 4> updatedArgAttrs;
// NOTE(review): `auto` copies the index vector; `const auto &` would avoid
// the copy.
660 auto erasedIndices = erasedArgIndices[funcOp];
661 DenseSet<unsigned> indicesToErase(erasedIndices.begin(),
662 erasedIndices.end());
663 for (
unsigned i = 0; i < existingArgAttrs.size(); ++i) {
664 if (!indicesToErase.contains(i))
665 updatedArgAttrs.push_back(existingArgAttrs[i]);
668 funcOp->setAttr(
"arg_attrs",
669 ArrayAttr::get(funcOp.getContext(), updatedArgAttrs));
// Rebuild the input types from the surviving block arguments; results are
// unchanged.
672 SmallVector<Type, 4> newArgTypes;
673 for (BlockArgument arg : funcOp.getArguments()) {
674 newArgTypes.push_back(arg.getType());
676 FunctionType newFuncType =
677 FunctionType::get(funcOp.getContext(), newArgTypes,
678 funcOp.getFunctionType().getResults());
679 funcOp.setType(newFuncType);
// Pass entry point:
//  - skip external operations and bankingFactor == 1 (nothing to split);
//  - reject bankingFactor == 0;
//  - create banks for every memref used inside an affine.parallel;
//  - apply the banking rewrite patterns greedily.
// NOTE(review): the error below says "greater than 1", but it fires only for
// 0; a factor of 1 is accepted as a no-op by the check above. The message
// likely wants "greater than 0" / "at least 1"; confirm.
685void MemoryBankingPass::runOnOperation() {
686 if (getOperation().isExternal() || bankingFactor == 1)
689 if (bankingFactor == 0) {
690 getOperation().emitError(
"banking factor must be greater than 1");
695 getOperation().walk([&](mlir::affine::AffineParallelOp parOp) {
698 for (
auto memrefVal : memrefsInPar) {
// insert() does not overwrite an existing entry. Whether banks are created
// only when `inserted` is true depends on a line not visible here; confirm.
699 auto [it, inserted] =
700 memoryToBanks.insert(std::make_pair(memrefVal, SmallVector<Value>{}));
702 it->second =
createBanks(memrefVal, bankingFactor, bankingDimension);
706 auto *ctx = &getContext();
709 DenseSet<Operation *> processedOps;
711 memoryToBanks, oldMemRefVals);
713 memoryToBanks, opsToErase, processedOps,
// ExistingOps: the greedy driver processes only ops that existed before it
// started. Presumably this stops the patterns from re-matching the banked
// affine loads/stores they create.
717 GreedyRewriteConfig config;
718 config.strictMode = GreedyRewriteStrictness::ExistingOps;
720 applyPatternsGreedily(getOperation(), std::move(
patterns), config))) {
// Factory for the memory-banking pass. Forwards the optional banking factor
// and dimension to the MemoryBankingPass constructor.
// NOTE(review): that constructor's body is empty in this file, so these
// arguments currently have no effect; confirm and fix in the constructor.
731std::unique_ptr<mlir::Pass>
733 std::optional<int> bankingDimension) {
734 return std::make_unique<MemoryBankingPass>(bankingFactor, bankingDimension);
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
SmallVector< Value, 4 > handleGetGlobalOp(memref::GetGlobalOp getGlobalOp, uint64_t bankingFactor, unsigned bankingDimension, MemRefType newMemRefType, OpBuilder &builder)
DenseSet< Value > collectMemRefs(mlir::affine::AffineParallelOp parOp)
void verifyBankingConfigurations(unsigned bankingDimension, unsigned bankingFactor, MemRefType originalType)
SmallVector< Value, 4 > createBanks(Value originalMem, unsigned bankingFactor, std::optional< int > bankingDimensionOpt)
void updateFuncOpArgAttrs(func::FuncOp funcOp, unsigned argIndex, unsigned numInsertedArgs)
SmallVector< int64_t > decodeIndex(int64_t linIndex, ArrayRef< int64_t > shape)
LogicalResult cleanUpOldMemRefs(DenseSet< Value > &oldMemRefVals, DenseSet< Operation * > &opsToErase)
void updateFuncOpArgumentTypes(func::FuncOp funcOp, unsigned argIndex, MemRefType newMemRefType, unsigned numInsertedArgs)
unsigned getSpecifiedOrDefaultBankingDim(std::optional< int > bankingDimensionOpt, int64_t rank, ArrayRef< int64_t > shape)
MemRefType computeBankedMemRefType(MemRefType originalType, uint64_t bankingFactor, unsigned bankingDimension)
SmallVector< SmallVector< Attribute > > sliceSubBlock(ArrayRef< Attribute > allAttrs, ArrayRef< int64_t > memShape, unsigned bankingDimension, unsigned bankingFactor)
void resolveBankingAttributes(Value originalMem, unsigned &bankingFactor, unsigned &bankingDimension)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createMemoryBankingPass(std::optional< unsigned > bankingFactor=std::nullopt, std::optional< int > bankingDimension=std::nullopt)
DenseMap< Value, SmallVector< Value > > & memoryToBanks
std::optional< int > bankingDimensionOpt
LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp, PatternRewriter &rewriter) const override
BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor, std::optional< int > bankingDimensionOpt, DenseMap< Value, SmallVector< Value > > &memoryToBanks, DenseSet< Value > &oldMemRefVals)
DenseSet< Value > & oldMemRefVals
BankAffineStorePattern(MLIRContext *context, uint64_t bankingFactor, std::optional< int > bankingDimensionOpt, DenseMap< Value, SmallVector< Value > > &memoryToBanks, DenseSet< Operation * > &opsToErase, DenseSet< Operation * > &processedOps, DenseSet< Value > &oldMemRefVals)
DenseMap< Value, SmallVector< Value > > & memoryToBanks
DenseSet< Value > & oldMemRefVals
DenseSet< Operation * > & processedOps
DenseSet< Operation * > & opsToErase
LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp, PatternRewriter &rewriter) const override
std::optional< int > bankingDimensionOpt
DenseMap< Value, SmallVector< Value > > & memoryToBanks
BankReturnPattern(MLIRContext *context, DenseMap< Value, SmallVector< Value > > &memoryToBanks)
LogicalResult matchAndRewrite(func::ReturnOp returnOp, PatternRewriter &rewriter) const override