15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/SCF/IR/SCF.h"
19#include "mlir/IR/AffineExpr.h"
20#include "mlir/IR/AffineMap.h"
21#include "mlir/IR/OperationSupport.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Support/LLVM.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
31#define GEN_PASS_DEF_MEMORYBANKING
32#include "circt/Transforms/Passes.h.inc"
// Bundles the banking configuration attached to one memory. Other code in
// this listing brace-initializes it from two `Attribute` values and reads its
// `factors` and `dimensions` members.
// NOTE(review): the member declarations themselves are elided from this
// listing -- confirm the field order against the full file.
39struct BankingConfigAttributes {
// Attribute names used to look up and store the banking factors and the
// banking dimensions, both on defining operations and on function arguments.
44constexpr std::string_view bankingFactorsStr =
"banking.factors";
45constexpr std::string_view bankingDimensionsStr =
"banking.dimensions";
// Pass object for memory banking. It derives from the generated
// `circt::impl::MemoryBankingBase` and keeps the banking configuration plus
// the bookkeeping shared between bank creation and the access rewrites.
// NOTE(review): several lines of this definition (including its end) are
// elided from this listing.
49struct MemoryBankingPass
50 :
public circt::impl::MemoryBankingBase<MemoryBankingPass> {
51 MemoryBankingPass(
const MemoryBankingPass &other) =
default;
// NOTE(review): the constructor body is empty, so `bankingFactors` and
// `bankingDimensions` passed here are not stored; `runOnOperation` fills the
// members from `bankingFactorsList` / `bankingDimensionsList` instead.
// Confirm that callers passing arguments get the configuration they expect.
52 explicit MemoryBankingPass(ArrayRef<unsigned> bankingFactors = {},
53 ArrayRef<unsigned> bankingDimensions = {}) {}
// Pass entry point.
55 void runOnOperation()
override;
// Applies the rewrite patterns to the given operation.
57 LogicalResult applyMemoryBanking(Operation *, MLIRContext *);
// Creates the banks that stand in for `originalMem`.
59 SmallVector<Value, 4> createBanks(OpBuilder &builder, Value originalMem);
// Attaches banking attributes to the memories accessed under the operation.
61 void setAllBankingAttributes(Operation *, MLIRContext *);
// Working copies of the configuration, refreshed at the start of
// `runOnOperation`.
64 SmallVector<unsigned, 4> bankingFactors;
65 SmallVector<unsigned, 4> bankingDimensions;
// Maps an original memory value to the bank values created for it.
67 DenseMap<Value, SmallVector<Value>> memoryToBanks;
// Operations scheduled for removal -- presumably consumed by the clean-up
// step; the use is not visible in this listing.
68 DenseSet<Operation *> opsToErase;
// Original memory values that have been replaced by banks; cleared at the
// start of every banking iteration in `runOnOperation`.
71 DenseSet<Value> oldMemRefVals;
76 Attribute bankingFactorsAttr, bankingDimensionsAttr;
77 if (
auto blockArg = dyn_cast<BlockArgument>(originalMem)) {
78 Block *block = blockArg.getOwner();
80 auto *parentOp = block->getParentOp();
81 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
83 "Expected the original memory to be a FuncOp block argument!");
84 unsigned argIndex = blockArg.getArgNumber();
85 if (
auto argAttrs = funcOp.getArgAttrDict(argIndex)) {
86 bankingFactorsAttr = argAttrs.get(bankingFactorsStr);
87 bankingDimensionsAttr = argAttrs.get(bankingDimensionsStr);
90 Operation *originalDef = originalMem.getDefiningOp();
91 bankingFactorsAttr = originalDef->getAttr(bankingFactorsStr);
92 bankingDimensionsAttr = originalDef->getAttr(bankingDimensionsStr);
94 return BankingConfigAttributes{bankingFactorsAttr, bankingDimensionsAttr};
99 DenseSet<Value> memrefVals;
100 affineParallelOp.walk([&](Operation *op) {
101 if (!isa<affine::AffineWriteOpInterface>(op) &&
102 !isa<affine::AffineReadOpInterface>(op))
103 return WalkResult::advance();
105 auto read = dyn_cast<affine::AffineReadOpInterface>(op);
106 Value memref = read ? read.getMemRef()
107 : cast<affine::AffineWriteOpInterface>(op).getMemRef();
108 memrefVals.insert(memref);
109 return WalkResult::advance();
116 unsigned bankingDimension,
117 MemRefType originalType) {
118 ArrayRef<int64_t> originalShape = originalType.getShape();
119 assert(!originalShape.empty() &&
"memref shape should not be empty");
120 assert(bankingDimension < originalType.getRank() &&
121 "dimension must be within the memref rank");
122 assert(originalShape[bankingDimension] % bankingFactor == 0 &&
123 "memref shape must be evenly divided by the banking factor");
127 uint64_t bankingFactor,
128 unsigned bankingDimension) {
129 ArrayRef<int64_t> originalShape = originalType.getShape();
130 SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
131 newShape[bankingDimension] /= bankingFactor;
132 MemRefType newMemRefType =
133 MemRefType::get(newShape, originalType.getElementType(),
134 originalType.getLayout(), originalType.getMemorySpace());
136 return newMemRefType;
142SmallVector<int64_t>
decodeIndex(int64_t linIndex, ArrayRef<int64_t> shape) {
143 const unsigned rank = shape.size();
144 SmallVector<int64_t> ndIndex(rank, 0);
147 for (int64_t d = rank - 1; d >= 0; --d) {
148 ndIndex[d] = linIndex % shape[d];
149 linIndex /= shape[d];
142SmallVector<int64_t>
decodeIndex(int64_t linIndex, ArrayRef<int64_t> shape) {
…}
159SmallVector<SmallVector<Attribute>>
sliceSubBlock(ArrayRef<Attribute> allAttrs,
160 ArrayRef<int64_t> memShape,
161 unsigned bankingDimension,
162 unsigned bankingFactor) {
163 size_t numElements = std::reduce(memShape.begin(), memShape.end(), 1,
164 std::multiplies<size_t>());
167 SmallVector<SmallVector<Attribute>> subBlocks;
168 subBlocks.resize(bankingFactor);
170 for (
unsigned linIndex = 0; linIndex <
numElements; ++linIndex) {
171 SmallVector<int64_t> ndIndex =
decodeIndex(linIndex, memShape);
172 unsigned subBlockIndex = ndIndex[bankingDimension] % bankingFactor;
173 subBlocks[subBlockIndex].push_back(allAttrs[linIndex]);
159SmallVector<SmallVector<Attribute>>
sliceSubBlock(ArrayRef<Attribute> allAttrs, {
…}
184 unsigned bankingDimension, MemRefType newMemRefType,
185 OpBuilder &builder, DictionaryAttr remainingAttrs) {
186 SmallVector<Value, 4> banks;
187 auto memTy = cast<MemRefType>(getGlobalOp.getType());
188 ArrayRef<int64_t> originalShape = memTy.getShape();
190 SmallVector<int64_t>(originalShape.begin(), originalShape.end());
191 newShape[bankingDimension] = originalShape[bankingDimension] / bankingFactor;
193 auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
194 auto globalOpNameAttr = getGlobalOp.getNameAttr();
195 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
196 SymbolTable::lookupSymbolIn(symbolTableOp, globalOpNameAttr));
197 assert(globalOp &&
"The corresponding GlobalOp should exist in the module");
198 MemRefType globalOpTy = globalOp.getType();
201 dyn_cast_or_null<DenseElementsAttr>(globalOp.getConstantInitValue());
202 auto attributes = cstAttr.getValues<Attribute>();
203 SmallVector<Attribute, 8> allAttrs(attributes.begin(), attributes.end());
206 sliceSubBlock(allAttrs, originalShape, bankingDimension, bankingFactor);
212 builder.setInsertionPointAfter(globalOp);
213 OpBuilder::InsertPoint globalOpsInsertPt = builder.saveInsertionPoint();
214 builder.setInsertionPointAfter(getGlobalOp);
215 OpBuilder::InsertPoint getGlobalOpsInsertPt = builder.saveInsertionPoint();
217 for (
size_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
219 auto newMemRefTy = MemRefType::get(newShape, globalOpTy.getElementType());
220 auto newTypeAttr = TypeAttr::get(newMemRefTy);
221 std::string newName = llvm::formatv(
222 "{0}_{1}_{2}", globalOpNameAttr.getValue(),
"bank", bankCnt);
223 RankedTensorType tensorType =
224 RankedTensorType::get({newShape}, globalOpTy.getElementType());
225 auto newInitValue = DenseElementsAttr::get(tensorType, subBlocks[bankCnt]);
227 builder.restoreInsertionPoint(globalOpsInsertPt);
228 auto newGlobalOp = builder.create<memref::GlobalOp>(
229 globalOp.getLoc(), builder.getStringAttr(newName),
230 globalOp.getSymVisibilityAttr(), newTypeAttr, newInitValue,
231 globalOp.getConstantAttr(), globalOp.getAlignmentAttr());
232 builder.setInsertionPointAfter(newGlobalOp);
233 globalOpsInsertPt = builder.saveInsertionPoint();
235 builder.restoreInsertionPoint(getGlobalOpsInsertPt);
236 auto newGetGlobalOp = builder.create<memref::GetGlobalOp>(
237 getGlobalOp.getLoc(), newMemRefTy, newGlobalOp.getName());
238 newGetGlobalOp->setAttrs(remainingAttrs);
239 builder.setInsertionPointAfter(newGetGlobalOp);
240 getGlobalOpsInsertPt = builder.saveInsertionPoint();
242 banks.push_back(newGetGlobalOp);
// Picks the dimension(s) to bank along: the caller-specified
// `bankingDimensions` when that list is non-empty, otherwise a single default
// dimension derived from `shape`.
// NOTE(review): the line carrying the function name and first parameter is
// elided from this listing.
249SmallVector<unsigned, 4>
251 int64_t rank, ArrayRef<int64_t> shape) {
// Explicitly specified dimensions are returned unchanged.
253 if (!bankingDimensions.empty()) {
254 return SmallVector<unsigned, 4>(bankingDimensions.begin(),
255 bankingDimensions.end());
// Default: scan from the last dimension towards the first for an extent
// larger than 1. -1 marks "nothing found yet".
261 int bankingDimension = -1;
262 for (
int dim = rank - 1; dim >= 0; --dim) {
263 if (shape[dim] > 1) {
// NOTE(review): presumably the loop stops at this first match; the line
// following the assignment is elided here -- confirm.
264 bankingDimension = dim;
// A memory in which no dimension has an extent above 1 cannot be banked.
269 assert(bankingDimension >= 0 &&
"No eligible dimension for banking");
270 return SmallVector<unsigned, 4>{
static_cast<unsigned>(bankingDimension)};
276 MemRefType newMemRefType,
277 unsigned numInsertedArgs) {
278 auto originalArgTypes = funcOp.getFunctionType().getInputs();
279 SmallVector<Type, 4> updatedArgTypes;
283 for (
unsigned i = 0; i < originalArgTypes.size(); ++i) {
284 updatedArgTypes.push_back(originalArgTypes[i]);
288 for (
unsigned j = 0; j < numInsertedArgs; ++j) {
289 updatedArgTypes.push_back(newMemRefType);
295 auto resultTypes = funcOp.getFunctionType().getResults();
297 FunctionType::get(funcOp.getContext(), updatedArgTypes, resultTypes);
298 funcOp.setType(newFuncType);
304 unsigned numInsertedArgs,
305 DictionaryAttr remainingAttrs) {
306 ArrayAttr existingArgAttrs = funcOp->getAttrOfType<ArrayAttr>(
"arg_attrs");
307 SmallVector<Attribute, 4> updatedArgAttrs;
308 unsigned numArguments = funcOp.getNumArguments();
309 unsigned newNumArguments = numArguments + numInsertedArgs;
310 updatedArgAttrs.resize(newNumArguments);
313 for (
unsigned i = 0; i < numArguments; ++i) {
315 unsigned newIndex = (i > argIndex) ? i + numInsertedArgs : i;
316 updatedArgAttrs[newIndex] = existingArgAttrs
317 ? existingArgAttrs[i]
318 : DictionaryAttr::get(funcOp.getContext());
322 for (
unsigned i = 0; i < numInsertedArgs; ++i) {
323 updatedArgAttrs[argIndex + 1 + i] = remainingAttrs;
327 funcOp->setAttr(
"arg_attrs",
328 ArrayAttr::get(funcOp.getContext(), updatedArgAttrs));
332 StringRef attrName) {
333 auto getFirstInteger = [](Attribute attr) ->
unsigned {
334 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
335 assert(!arrayAttr.empty() &&
336 "BankingConfig ArrayAttr should not be empty");
337 auto intAttr = dyn_cast<IntegerAttr>(arrayAttr.getValue().front());
338 assert(intAttr &&
"BankingConfig elements must be integers");
339 return intAttr.getInt();
341 auto intAttr = dyn_cast<IntegerAttr>(attr);
342 assert(intAttr &&
"BankingConfig attribute must be an integer");
343 return intAttr.getInt();
346 if (attrName.str() == bankingFactorsStr) {
347 return getFirstInteger(bankingConfigAttrs.factors);
350 assert(attrName.str() == bankingDimensionsStr &&
351 "BankingConfig only contains 'factors' and 'dimensions' attributes");
352 return getFirstInteger(bankingConfigAttrs.dimensions);
356 BankingConfigAttributes bankingConfigAttrs,
357 StringRef attrName) {
358 auto getRemainingElements = [context](Attribute attr) -> Attribute {
359 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
360 assert(!arrayAttr.empty() &&
361 "BankingConfig ArrayAttr should not be empty");
362 return arrayAttr.size() > 1
363 ? ArrayAttr::get(context, arrayAttr.getValue().take_back(
364 arrayAttr.size() - 1))
367 assert(dyn_cast<IntegerAttr>(attr) &&
368 "BankingConfig attribute must be an integer");
372 if (attrName.str() == bankingFactorsStr) {
373 return getRemainingElements(bankingConfigAttrs.factors);
376 assert(attrName.str() == bankingDimensionsStr &&
377 "BankingConfig only contains 'factors' and 'dimensions' attributes");
378 return getRemainingElements(bankingConfigAttrs.dimensions);
// Creates the bank memories that take the place of `originalMem`, collecting
// them in the local `banks` vector.
// NOTE(review): many right-hand sides, the second parameter line and the end
// of the function (presumably `return banks;`) are elided from this listing;
// the comments below describe only the statements that are visible.
381SmallVector<Value, 4> MemoryBankingPass::createBanks(OpBuilder &builder,
383 MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
385 MLIRContext *context = builder.getContext();
// Banking configuration of the memory, and the factor/dimension used for this
// round of banking (initializers elided).
387 BankingConfigAttributes currBankingConfig =
390 unsigned currFactor =
392 unsigned currDimension =
// Configuration left over for the created banks (initializers elided).
397 Attribute remainingFactors =
399 Attribute remainingDimensions =
// Attribute dictionary given to every new bank: either the remaining
// factors/dimensions or an empty dictionary (the selecting condition is
// elided).
401 DictionaryAttr remainingAttrs =
403 ? DictionaryAttr::get(
405 {builder.getNamedAttr(bankingFactorsStr, remainingFactors),
406 builder.getNamedAttr(bankingDimensionsStr,
407 remainingDimensions)})
408 : DictionaryAttr::
get(context);
// Type of a single bank (initializer elided).
410 MemRefType newMemRefType =
412 SmallVector<Value, 4> banks;
// Block-argument case: insert `currFactor` new arguments of the bank type
// directly after the original argument and use them as the banks.
413 if (
auto blockArgMem = dyn_cast<BlockArgument>(originalMem)) {
414 Block *block = blockArgMem.getOwner();
415 unsigned blockArgNum = blockArgMem.getArgNumber();
417 for (
unsigned i = 0; i < currFactor; ++i)
418 block->insertArgument(blockArgNum + 1 + i, newMemRefType,
419 blockArgMem.getLoc());
421 auto blockArgs = block->getArguments().slice(blockArgNum + 1, currFactor);
422 banks.append(blockArgs.begin(), blockArgs.end());
// The owning operation must be a `func::FuncOp`; what is done with it
// afterwards is elided here.
424 auto *parentOp = block->getParentOp();
425 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
426 assert(funcOp &&
"BlockArgument is not part of a FuncOp");
// Defining-op case: create the banks right after the original definition,
// dispatching on the kind of defining operation.
433 Operation *originalDef = originalMem.getDefiningOp();
434 Location loc = originalDef->getLoc();
435 builder.setInsertionPointAfter(originalDef);
436 TypeSwitch<Operation *>(originalDef)
// memref.alloc: one new alloc of the bank type per bank, carrying
// `remainingAttrs`.
437 .Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
438 for (uint64_t bankCnt = 0; bankCnt < currFactor; ++bankCnt) {
440 builder.create<memref::AllocOp>(loc, newMemRefType);
441 bankAllocOp->setAttrs(remainingAttrs);
442 banks.push_back(bankAllocOp);
// memref.alloca: same scheme as for memref.alloc.
445 .Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
446 for (uint64_t bankCnt = 0; bankCnt < currFactor; ++bankCnt) {
448 builder.create<memref::AllocaOp>(loc, newMemRefType);
449 bankAllocaOp->setAttrs(remainingAttrs);
450 banks.push_back(bankAllocaOp);
// memref.get_global: the banks come from a helper call whose callee line is
// elided (presumably `handleGetGlobalOp` -- confirm).
453 .Case<memref::GetGlobalOp>([&](memref::GetGlobalOp getGlobalOp) {
456 newMemRefType, builder, remainingAttrs);
457 banks.append(newBanks.begin(), newBanks.end());
// Any other defining operation is treated as unreachable.
459 .Default([](Operation *) {
460 llvm_unreachable(
"Unhandled memory operation type");
477 PatternRewriter &rewriter)
const override {
478 Location loc = loadOp.getLoc();
479 auto originalMem = loadOp.getMemref();
481 auto loadIndices = loadOp.getIndices();
482 MemRefType originalMemRefType = loadOp.getMemRefType();
483 int64_t memrefRank = originalMemRefType.getRank();
485 BankingConfigAttributes currBankingConfig =
487 if (!currBankingConfig.factors) {
492 unsigned currFactor =
494 unsigned currDimension =
499 auto modMap = AffineMap::get(
501 {rewriter.getAffineDimExpr(currDimension) % currFactor});
502 auto divMap = AffineMap::get(
504 {rewriter.getAffineDimExpr(currDimension).floorDiv(currFactor)});
507 rewriter.create<affine::AffineApplyOp>(loc, modMap, loadIndices);
509 rewriter.create<affine::AffineApplyOp>(loc, divMap, loadIndices);
510 SmallVector<Value, 4> newIndices(loadIndices.begin(), loadIndices.end());
511 newIndices[currDimension] = offset;
513 SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
515 SmallVector<int64_t, 4> caseValues;
516 for (
unsigned i = 0; i < currFactor; ++i)
517 caseValues.push_back(i);
519 rewriter.setInsertionPoint(loadOp);
520 scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
521 loc, resultTypes, bankIndex, caseValues,
524 for (
unsigned i = 0; i < currFactor; ++i) {
525 Region &caseRegion = switchOp.getCaseRegions()[i];
526 rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
527 Value bankedLoad = rewriter.create<mlir::affine::AffineLoadOp>(
528 loc, banks[i], newIndices);
529 rewriter.create<scf::YieldOp>(loc, bankedLoad);
532 Region &defaultRegion = switchOp.getDefaultRegion();
533 assert(defaultRegion.empty() &&
"Default region should be empty");
534 rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
537 cast<TypedAttr>(rewriter.getZeroAttr(loadOp.getType()));
538 auto defaultValue = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
539 rewriter.create<scf::YieldOp>(loc, defaultValue.getResult());
543 if (Value memRef = loadOp.getMemref(); isa<BlockArgument>(memRef)) {
546 rewriter.replaceOp(loadOp, switchOp.getResult(0));
569 PatternRewriter &rewriter)
const override {
574 if (!currConfig.factors) {
578 Location loc = storeOp.getLoc();
579 auto originalMem = storeOp.getMemref();
581 auto storeIndices = storeOp.getIndices();
582 auto originalMemRefType = storeOp.getMemRefType();
583 int64_t memrefRank = originalMemRefType.getRank();
585 BankingConfigAttributes currBankingConfig =
588 unsigned currFactor =
590 unsigned currDimension =
595 auto modMap = AffineMap::get(
597 {rewriter.getAffineDimExpr(currDimension) % currFactor});
598 auto divMap = AffineMap::get(
600 {rewriter.getAffineDimExpr(currDimension).floorDiv(currFactor)});
603 rewriter.create<affine::AffineApplyOp>(loc, modMap, storeIndices);
605 rewriter.create<affine::AffineApplyOp>(loc, divMap, storeIndices);
606 SmallVector<Value, 4> newIndices(storeIndices.begin(), storeIndices.end());
607 newIndices[currDimension] = offset;
609 SmallVector<Type> resultTypes = {};
611 SmallVector<int64_t, 4> caseValues;
612 for (
unsigned i = 0; i < currFactor; ++i)
613 caseValues.push_back(i);
615 rewriter.setInsertionPoint(storeOp);
616 scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
617 loc, resultTypes, bankIndex, caseValues,
620 for (
unsigned i = 0; i < currFactor; ++i) {
621 Region &caseRegion = switchOp.getCaseRegions()[i];
622 rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
623 rewriter.create<mlir::affine::AffineStoreOp>(
624 loc, storeOp.getValueToStore(), banks[i], newIndices);
625 rewriter.create<scf::YieldOp>(loc);
628 Region &defaultRegion = switchOp.getDefaultRegion();
629 assert(defaultRegion.empty() &&
"Default region should be empty");
630 rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
632 rewriter.create<scf::YieldOp>(loc);
656 PatternRewriter &rewriter)
const override {
657 Location loc = returnOp.getLoc();
658 SmallVector<Value, 4> newReturnOperands;
659 bool allOrigMemsUsedByReturn =
true;
660 for (
auto operand : returnOp.getOperands()) {
662 newReturnOperands.push_back(operand);
665 if (operand.hasOneUse())
666 allOrigMemsUsedByReturn =
false;
668 newReturnOperands.append(banks.begin(), banks.end());
671 func::FuncOp funcOp = returnOp.getParentOp();
672 rewriter.setInsertionPointToEnd(&funcOp.getBlocks().front());
674 rewriter.create<func::ReturnOp>(loc, ValueRange(newReturnOperands));
675 TypeRange newReturnType = TypeRange(newReturnOperands);
676 FunctionType newFuncType =
677 FunctionType::get(funcOp.getContext(),
678 funcOp.getFunctionType().getInputs(), newReturnType);
679 funcOp.setType(newFuncType);
681 if (allOrigMemsUsedByReturn)
682 rewriter.replaceOp(returnOp, newReturnOp);
696 DenseSet<Operation *> &opsToErase) {
697 DenseSet<func::FuncOp> funcsToModify;
698 SmallVector<Value, 4> valuesToErase;
699 DenseMap<func::FuncOp, SmallVector<unsigned, 4>> erasedArgIndices;
700 for (
auto &memrefVal : oldMemRefVals) {
701 valuesToErase.push_back(memrefVal);
702 if (
auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
704 dyn_cast<func::FuncOp>(blockArg.getOwner()->getParentOp())) {
705 funcsToModify.insert(funcOp);
706 erasedArgIndices[funcOp].push_back(blockArg.getArgNumber());
711 for (
auto *op : opsToErase) {
715 for (
auto &memrefVal : valuesToErase) {
716 assert(memrefVal.use_empty() &&
"use must be empty");
717 if (
auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
718 blockArg.getOwner()->eraseArgument(blockArg.getArgNumber());
719 }
else if (
auto *op = memrefVal.getDefiningOp()) {
725 for (
auto funcOp : funcsToModify) {
726 ArrayAttr existingArgAttrs = funcOp->getAttrOfType<ArrayAttr>(
"arg_attrs");
727 if (existingArgAttrs) {
728 SmallVector<Attribute, 4> updatedArgAttrs;
729 auto erasedIndices = erasedArgIndices[funcOp];
730 DenseSet<unsigned> indicesToErase(erasedIndices.begin(),
731 erasedIndices.end());
732 for (
unsigned i = 0; i < existingArgAttrs.size(); ++i) {
733 if (!indicesToErase.contains(i))
734 updatedArgAttrs.push_back(existingArgAttrs[i]);
737 funcOp->setAttr(
"arg_attrs",
738 ArrayAttr::get(funcOp.getContext(), updatedArgAttrs));
741 SmallVector<Type, 4> newArgTypes;
742 for (BlockArgument arg : funcOp.getArguments()) {
743 newArgTypes.push_back(arg.getType());
745 FunctionType newFuncType =
746 FunctionType::get(funcOp.getContext(), newArgTypes,
747 funcOp.getFunctionType().getResults());
748 funcOp.setType(newFuncType);
755 Attribute bankingDimensionsAttr) {
756 if (
auto factorsArrayAttr = dyn_cast<ArrayAttr>(bankingFactorsAttr)) {
757 assert(!factorsArrayAttr.empty() &&
"Banking factors should not be empty");
758 if (
auto dimsArrayAttr = dyn_cast<ArrayAttr>(bankingDimensionsAttr)) {
759 assert(factorsArrayAttr.size() == dimsArrayAttr.size() &&
760 "Banking factors/dimensions must be paired together");
762 auto dimsIntAttr = dyn_cast<IntegerAttr>(bankingDimensionsAttr);
763 assert(dimsIntAttr &&
"banking.dimensions can either be an integer or an "
764 "array of integers");
765 assert(factorsArrayAttr.size() == 1 &&
766 "Banking factors/dimensions must be paired together");
769 auto factorsIntAttr = dyn_cast<IntegerAttr>(bankingFactorsAttr);
771 "banking.factors can either be an integer or an array of integers");
772 if (
auto dimsArrayAttr = dyn_cast<ArrayAttr>(bankingDimensionsAttr)) {
773 assert(dimsArrayAttr.size() == 1 &&
774 "Banking factors/dimensions must be paired together");
776 auto dimsIntAttr = dyn_cast<IntegerAttr>(bankingDimensionsAttr);
777 assert(dimsIntAttr &&
"banking.dimensions can either be an integer or an "
778 "array of integers");
// Walks every affine read/write nested in an `affine.parallel` op under
// `operation` and makes sure the accessed memory carries both banking
// attributes. Attributes that are already present are left untouched; missing
// ones are filled in from the pass configuration.
// NOTE(review): some lines are elided from this listing, including the calls
// whose trailing arguments appear after each branch.
783void MemoryBankingPass::setAllBankingAttributes(Operation *operation,
784 MLIRContext *context) {
// Default factors attribute: the configured factors as an array of 32-bit
// integer attributes.
785 ArrayAttr defaultFactorsAttr = ArrayAttr::get(
787 llvm::map_to_vector(bankingFactors, [&](
unsigned factor) -> Attribute {
788 return IntegerAttr::get(IntegerType::get(context, 32), factor);
// Builds the dimensions attribute from a list of dimensions, again as 32-bit
// integer attributes.
791 auto getDimensionsAttr =
792 [&](SmallVector<unsigned, 4> specifiedOrDefaultDims) -> ArrayAttr {
793 return ArrayAttr::get(
794 context, llvm::map_to_vector(specifiedOrDefaultDims,
795 [&](
unsigned dim) -> Attribute {
796 return IntegerAttr::get(
797 IntegerType::get(context, 32), dim);
803 operation->walk([&](affine::AffineParallelOp affineParallelOp) {
804 affineParallelOp.walk([&](Operation *op) {
// Only affine reads and writes are of interest.
805 if (!isa<affine::AffineWriteOpInterface, affine::AffineReadOpInterface>(
807 return WalkResult::advance();
// Fetch the accessed memory and its type through whichever interface the
// operation implements.
809 auto read = dyn_cast<affine::AffineReadOpInterface>(op);
812 : cast<affine::AffineWriteOpInterface>(op).getMemRef();
813 MemRefType memrefType =
814 read ? read.getMemRefType()
815 : cast<affine::AffineWriteOpInterface>(op).getMemRefType();
// Memory produced by an operation: the attributes live on that operation.
817 if (
auto *originalDef = memref.getDefiningOp()) {
819 if (!originalDef->getAttr(bankingFactorsStr)) {
820 originalDef->setAttr(bankingFactorsStr, defaultFactorsAttr);
825 if (!originalDef->getAttr(bankingDimensionsStr)) {
826 SmallVector<unsigned, 4> specifiedOrDefaultDims =
828 memrefType.getRank(),
829 memrefType.getShape());
831 originalDef->setAttr(bankingDimensionsStr,
832 getDimensionsAttr(specifiedOrDefaultDims));
// Trailing argument of a call whose start is elided -- presumably the
// factors/dimensions size check; confirm.
836 originalDef->getAttr(bankingDimensionsStr));
837 }
// Memory passed in as a block argument: the attributes live in the argument
// attributes of the enclosing function.
else if (isa<BlockArgument>(memref)) {
838 auto blockArg = cast<BlockArgument>(memref);
839 auto *parentOp = blockArg.getOwner()->getParentOp();
840 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
842 "Expected the original memory to be a FuncOp block argument!");
843 unsigned argIndex = blockArg.getArgNumber();
844 SmallVector<unsigned, 4> specifiedOrDefaultDims =
846 bankingDimensions, memrefType.getRank(), memrefType.getShape());
848 if (!funcOp.getArgAttr(argIndex, bankingFactorsStr))
849 funcOp.setArgAttr(argIndex, bankingFactorsStr, defaultFactorsAttr);
850 if (!funcOp.getArgAttr(argIndex, bankingDimensionsStr))
851 funcOp.setArgAttr(argIndex, bankingDimensionsStr,
852 getDimensionsAttr(specifiedOrDefaultDims));
// Arguments of a call whose start is elided -- presumably the same size check
// as in the defining-op branch; confirm.
855 funcOp.getArgAttr(argIndex, bankingFactorsStr),
856 funcOp.getArgAttr(argIndex, bankingDimensionsStr));
858 return WalkResult::advance();
// Pass entry point. Copies the option lists into the member vectors, checks
// them, attaches banking attributes, and then repeats "create banks, rewrite
// accesses" in a do/while loop for as long as `banksCreated` is set.
// NOTE(review): several lines are elided from this listing, among them the
// statements guarded by the checks below, the `do {` header, and the place
// where `banksCreated` becomes true.
863void MemoryBankingPass::runOnOperation() {
864 this->bankingFactors = {bankingFactorsList.begin(), bankingFactorsList.end()};
865 this->bankingDimensions = {bankingDimensionsList.begin(),
866 bankingDimensionsList.end()};
// Trivial case: external operation, no factors given, or every factor is 1.
868 if (getOperation().isExternal() ||
869 (bankingFactors.empty() ||
870 std::all_of(bankingFactors.begin(), bankingFactors.end(),
871 [](
unsigned f) { return f == 1; })))
// A factor of 0 is an error.
// NOTE(review): the predicate only tests for 0 while the message says
// "greater than 1", and the lambda takes `int` although the elements are
// `unsigned` -- confirm the intended rule.
874 if (std::any_of(bankingFactors.begin(), bankingFactors.end(),
875 [](
int f) { return f == 0; })) {
876 getOperation().emitError(
"banking factor must be greater than 1");
// There may not be more dimensions than factors.
881 if (bankingDimensions.size() > bankingFactors.size()) {
882 getOperation().emitError(
883 "A banking dimension must be paired with a factor");
// Make sure every accessed memory carries its banking attributes.
890 setAllBankingAttributes(getOperation(), &getContext());
892 OpBuilder builder(getOperation());
897 bool banksCreated =
false;
// Start of one banking iteration: reset the per-iteration bookkeeping.
899 memoryToBanks.clear();
900 oldMemRefVals.clear();
903 banksCreated =
false;
// Visit every affine.parallel op and the memories used inside it.
904 getOperation().walk([&](mlir::affine::AffineParallelOp parOp) {
908 for (
auto memrefVal : memrefsInPar) {
// Memories without banking factors are handled separately (guarded statement
// elided).
910 if (!currConfig.factors) {
// `insert` reports whether this memory was new to the map.
// NOTE(review): presumably banks are only created in that case; the check on
// `inserted` is elided here -- confirm.
913 auto [it, inserted] = memoryToBanks.insert(
914 std::make_pair(memrefVal, SmallVector<Value>{}));
916 it->second = createBanks(builder, memrefVal);
// Rewrite the accesses so that they use the banks.
921 if (failed(applyMemoryBanking(getOperation(), &getContext()))) {
925 }
while (banksCreated);
// Applies the banking rewrite patterns to `operation` with the greedy pattern
// driver.
// NOTE(review): the second parameter line, the pattern-set construction and
// the result handling are elided from this listing.
928LogicalResult MemoryBankingPass::applyMemoryBanking(Operation *operation,
// Set of operations handed to the patterns together with `oldMemRefVals`.
932 DenseSet<Operation *> processedOps;
935 processedOps, oldMemRefVals);
// Strictness is set to `ExistingOps` for the greedy driver.
938 GreedyRewriteConfig config;
939 config.strictMode = GreedyRewriteStrictness::ExistingOps;
940 if (failed(applyPatternsGreedily(operation, std::move(
patterns), config))) {
// Factory for the pass: forwards both lists to the `MemoryBankingPass`
// constructor and returns the new pass as a `mlir::Pass`.
// NOTE(review): the line carrying the function name and first parameter is
// elided from this listing. The constructor visible in this file has an empty
// body -- confirm that the forwarded lists actually reach the pass.
953std::unique_ptr<mlir::Pass>
955 ArrayRef<unsigned> bankingDimensions) {
956 return std::make_unique<MemoryBankingPass>(bankingFactors, bankingDimensions);
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
BankingConfigAttributes getMemRefBankingConfig(Value originalMem)
void verifyBankingAttributesSize(Attribute bankingFactorsAttr, Attribute bankingDimensionsAttr)
void verifyBankingConfigurations(unsigned bankingFactor, unsigned bankingDimension, MemRefType originalType)
unsigned getCurrBankingInfo(BankingConfigAttributes bankingConfigAttrs, StringRef attrName)
DenseSet< Value > collectMemRefs(affine::AffineParallelOp affineParallelOp)
Attribute getRemainingBankingInfo(MLIRContext *context, BankingConfigAttributes bankingConfigAttrs, StringRef attrName)
SmallVector< int64_t > decodeIndex(int64_t linIndex, ArrayRef< int64_t > shape)
LogicalResult cleanUpOldMemRefs(DenseSet< Value > &oldMemRefVals, DenseSet< Operation * > &opsToErase)
SmallVector< Value, 4 > handleGetGlobalOp(memref::GetGlobalOp getGlobalOp, uint64_t bankingFactor, unsigned bankingDimension, MemRefType newMemRefType, OpBuilder &builder, DictionaryAttr remainingAttrs)
void updateFuncOpArgumentTypes(func::FuncOp funcOp, unsigned argIndex, MemRefType newMemRefType, unsigned numInsertedArgs)
MemRefType computeBankedMemRefType(MemRefType originalType, uint64_t bankingFactor, unsigned bankingDimension)
SmallVector< unsigned, 4 > getSpecifiedOrDefaultBankingDim(const ArrayRef< unsigned > bankingDimensions, int64_t rank, ArrayRef< int64_t > shape)
void updateFuncOpArgAttrs(func::FuncOp funcOp, unsigned argIndex, unsigned numInsertedArgs, DictionaryAttr remainingAttrs)
SmallVector< SmallVector< Attribute > > sliceSubBlock(ArrayRef< Attribute > allAttrs, ArrayRef< int64_t > memShape, unsigned bankingDimension, unsigned bankingFactor)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createMemoryBankingPass(ArrayRef< unsigned > bankingFactors={}, ArrayRef< unsigned > bankingDimensions={})
DenseMap< Value, SmallVector< Value > > & memoryToBanks
LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp, PatternRewriter &rewriter) const override
BankAffineLoadPattern(MLIRContext *context, DenseMap< Value, SmallVector< Value > > &memoryToBanks, DenseSet< Value > &oldMemRefVals)
DenseSet< Value > & oldMemRefVals
DenseMap< Value, SmallVector< Value > > & memoryToBanks
DenseSet< Value > & oldMemRefVals
DenseSet< Operation * > & processedOps
BankAffineStorePattern(MLIRContext *context, DenseMap< Value, SmallVector< Value > > &memoryToBanks, DenseSet< Operation * > &opsToErase, DenseSet< Operation * > &processedOps, DenseSet< Value > &oldMemRefVals)
DenseSet< Operation * > & opsToErase
LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp, PatternRewriter &rewriter) const override
DenseMap< Value, SmallVector< Value > > & memoryToBanks
BankReturnPattern(MLIRContext *context, DenseMap< Value, SmallVector< Value > > &memoryToBanks)
LogicalResult matchAndRewrite(func::ReturnOp returnOp, PatternRewriter &rewriter) const override