15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/OperationSupport.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Support/LLVM.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/raw_ostream.h"
30 #define GEN_PASS_DEF_MEMORYBANKING
31 #include "circt/Transforms/Passes.h.inc"
35 using namespace circt;
41 struct MemoryBankingPass
42 :
public circt::impl::MemoryBankingBase<MemoryBankingPass> {
43 MemoryBankingPass(
const MemoryBankingPass &other) =
default;
44 explicit MemoryBankingPass(
45 std::optional<unsigned> bankingFactor = std::nullopt) {}
47 void runOnOperation()
override;
// Maps each original memref value (collected from affine.parallel ops in
// runOnOperation) to the bank values that createBanks() produced for it.
51 DenseMap<Value, SmallVector<Value>> memoryToBanks;
// Operations queued for removal (the store-banking pattern inserts the
// affine.store ops it has rewritten); consumed by cleanUpOldMemRefs.
52 DenseSet<Operation *> opsToErase;
58 DenseSet<Value> memrefVals;
59 parOp.walk([&](Operation *op) {
60 for (
auto operand : op->getOperands()) {
61 if (isa<MemRefType>(operand.getType()))
62 memrefVals.insert(operand);
64 return WalkResult::advance();
70 uint64_t bankingFactor) {
71 ArrayRef<int64_t> originalShape = originalType.getShape();
72 assert(!originalShape.empty() &&
"memref shape should not be empty");
73 assert(originalType.getRank() == 1 &&
74 "currently only support one dimension memories");
75 SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
76 assert(newShape.front() % bankingFactor == 0 &&
77 "memref shape must be evenly divided by the banking factor");
78 newShape.front() /= bankingFactor;
79 MemRefType newMemRefType =
81 originalType.getLayout(), originalType.getMemorySpace());
86 SmallVector<Value, 4>
createBanks(Value originalMem, uint64_t bankingFactor) {
87 MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
88 MemRefType newMemRefType =
90 SmallVector<Value, 4> banks;
91 if (
auto blockArgMem = dyn_cast<BlockArgument>(originalMem)) {
92 Block *block = blockArgMem.getOwner();
93 unsigned blockArgNum = blockArgMem.getArgNumber();
95 SmallVector<Type> banksType;
96 for (
unsigned i = 0; i < bankingFactor; ++i) {
97 block->insertArgument(blockArgNum + 1 + i, newMemRefType,
98 blockArgMem.getLoc());
102 block->getArguments().slice(blockArgNum + 1, bankingFactor);
103 banks.append(blockArgs.begin(), blockArgs.end());
105 Operation *originalDef = originalMem.getDefiningOp();
106 Location loc = originalDef->getLoc();
107 OpBuilder builder(originalDef);
108 builder.setInsertionPointAfter(originalDef);
109 TypeSwitch<Operation *>(originalDef)
110 .Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
111 for (uint64_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
113 builder.create<memref::AllocOp>(loc, newMemRefType);
114 banks.push_back(bankAllocOp);
117 .Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
118 for (uint64_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
120 builder.create<memref::AllocaOp>(loc, newMemRefType);
121 banks.push_back(bankAllocaOp);
124 .Default([](Operation *) {
125 llvm_unreachable(
"Unhandled memory operation type");
135 DenseMap<Value, SmallVector<Value>> &memoryToBanks)
137 bankingFactor(bankingFactor), memoryToBanks(memoryToBanks) {}
140 PatternRewriter &rewriter)
const override {
141 Location loc = loadOp.getLoc();
142 auto banks = memoryToBanks[loadOp.getMemref()];
143 Value loadIndex = loadOp.getIndices().front();
145 AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor});
147 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)});
149 Value bankIndex = rewriter.create<affine::AffineApplyOp>(
150 loc, modMap, loadIndex);
152 rewriter.create<affine::AffineApplyOp>(loc, divMap, loadIndex);
154 SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
156 SmallVector<int64_t, 4> caseValues;
157 for (
unsigned i = 0; i < bankingFactor; ++i)
158 caseValues.push_back(i);
160 rewriter.setInsertionPoint(loadOp);
161 scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
162 loc, resultTypes, bankIndex, caseValues,
165 for (
unsigned i = 0; i < bankingFactor; ++i) {
166 Region &caseRegion = switchOp.getCaseRegions()[i];
167 rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
169 rewriter.create<mlir::affine::AffineLoadOp>(loc, banks[i], offset);
170 rewriter.create<scf::YieldOp>(loc, bankedLoad);
173 Region &defaultRegion = switchOp.getDefaultRegion();
174 assert(defaultRegion.empty() &&
"Default region should be empty");
175 rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
178 cast<TypedAttr>(rewriter.getZeroAttr(loadOp.getType()));
179 auto defaultValue = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
180 rewriter.create<scf::YieldOp>(loc, defaultValue.getResult());
182 rewriter.replaceOp(loadOp, switchOp.getResult(0));
196 DenseMap<Value, SmallVector<Value>> &memoryToBanks,
197 DenseSet<Operation *> &opsToErase,
198 DenseSet<Operation *> &processedOps)
200 bankingFactor(bankingFactor), memoryToBanks(memoryToBanks),
201 opsToErase(opsToErase), processedOps(processedOps) {}
204 PatternRewriter &rewriter)
const override {
205 if (processedOps.contains(storeOp)) {
208 Location loc = storeOp.getLoc();
209 auto banks = memoryToBanks[storeOp.getMemref()];
210 Value storeIndex = storeOp.getIndices().front();
213 AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor});
215 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)});
217 Value bankIndex = rewriter.create<affine::AffineApplyOp>(
218 loc, modMap, storeIndex);
220 rewriter.create<affine::AffineApplyOp>(loc, divMap, storeIndex);
222 SmallVector<Type> resultTypes = {};
224 SmallVector<int64_t, 4> caseValues;
225 for (
unsigned i = 0; i < bankingFactor; ++i)
226 caseValues.push_back(i);
228 rewriter.setInsertionPoint(storeOp);
229 scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
230 loc, resultTypes, bankIndex, caseValues,
233 for (
unsigned i = 0; i < bankingFactor; ++i) {
234 Region &caseRegion = switchOp.getCaseRegions()[i];
235 rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
236 rewriter.create<mlir::affine::AffineStoreOp>(
237 loc, storeOp.getValueToStore(), banks[i], offset);
238 rewriter.create<scf::YieldOp>(loc);
241 Region &defaultRegion = switchOp.getDefaultRegion();
242 assert(defaultRegion.empty() &&
"Default region should be empty");
243 rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
245 rewriter.create<scf::YieldOp>(loc);
247 processedOps.insert(storeOp);
248 opsToErase.insert(storeOp);
263 DenseMap<Value, SmallVector<Value>> &memoryToBanks)
265 memoryToBanks(memoryToBanks) {}
268 PatternRewriter &rewriter)
const override {
269 Location loc = returnOp.getLoc();
270 SmallVector<Value, 4> newReturnOperands;
271 bool allOrigMemsUsedByReturn =
true;
272 for (
auto operand : returnOp.getOperands()) {
273 if (!memoryToBanks.contains(operand)) {
274 newReturnOperands.push_back(operand);
277 if (operand.hasOneUse())
278 allOrigMemsUsedByReturn =
false;
279 auto banks = memoryToBanks[operand];
280 newReturnOperands.append(banks.begin(), banks.end());
283 func::FuncOp funcOp = returnOp.getParentOp();
284 rewriter.setInsertionPointToEnd(&funcOp.getBlocks().front());
286 rewriter.create<func::ReturnOp>(loc, ValueRange(newReturnOperands));
287 TypeRange newReturnType = TypeRange(newReturnOperands);
288 FunctionType newFuncType =
290 funcOp.getFunctionType().getInputs(), newReturnType);
291 funcOp.setType(newFuncType);
293 if (allOrigMemsUsedByReturn)
294 rewriter.replaceOp(returnOp, newReturnOp);
308 DenseSet<Operation *> &opsToErase) {
309 DenseSet<func::FuncOp> funcsToModify;
310 SmallVector<Value, 4> valuesToErase;
311 for (
auto &memrefVal : oldMemRefVals) {
312 valuesToErase.push_back(memrefVal);
313 if (
auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
315 dyn_cast<func::FuncOp>(blockArg.getOwner()->getParentOp()))
316 funcsToModify.insert(funcOp);
320 for (
auto *op : opsToErase) {
324 for (
auto &memrefVal : valuesToErase) {
325 assert(memrefVal.use_empty() &&
"use must be empty");
326 if (
auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
327 blockArg.getOwner()->eraseArgument(blockArg.getArgNumber());
328 }
else if (
auto *op = memrefVal.getDefiningOp()) {
334 for (
auto funcOp : funcsToModify) {
335 SmallVector<Type, 4> newArgTypes;
336 for (BlockArgument arg : funcOp.getArguments()) {
337 newArgTypes.push_back(arg.getType());
339 FunctionType newFuncType =
341 funcOp.getFunctionType().getResults());
342 funcOp.setType(newFuncType);
348 void MemoryBankingPass::runOnOperation() {
349 if (getOperation().isExternal() || bankingFactor == 1)
352 if (bankingFactor == 0) {
353 getOperation().emitError(
"banking factor must be greater than 1");
358 getOperation().walk([&](mlir::affine::AffineParallelOp parOp) {
361 for (
auto memrefVal : memrefsInPar)
362 memoryToBanks[memrefVal] =
createBanks(memrefVal, bankingFactor);
365 auto *ctx = &getContext();
368 DenseSet<Operation *> processedOps;
371 opsToErase, processedOps);
374 GreedyRewriteConfig config;
375 config.strictMode = GreedyRewriteStrictness::ExistingOps;
376 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(
patterns),
382 DenseSet<Value> oldMemRefVals;
383 for (
const auto &[memory, _] : memoryToBanks)
384 oldMemRefVals.insert(memory);
392 std::unique_ptr<mlir::Pass>
394 return std::make_unique<MemoryBankingPass>(bankingFactor);
assert(baseType &&"element must be base type")
SmallVector< Value, 4 > createBanks(Value originalMem, uint64_t bankingFactor)
DenseSet< Value > collectMemRefs(mlir::affine::AffineParallelOp parOp)
LogicalResult cleanUpOldMemRefs(DenseSet< Value > &oldMemRefVals, DenseSet< Operation * > &opsToErase)
MemRefType computeBankedMemRefType(MemRefType originalType, uint64_t bankingFactor)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createMemoryBankingPass(std::optional< unsigned > bankingFactor=std::nullopt)
DenseMap< Value, SmallVector< Value > > & memoryToBanks
BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor, DenseMap< Value, SmallVector< Value >> &memoryToBanks)
LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp, PatternRewriter &rewriter) const override
DenseMap< Value, SmallVector< Value > > & memoryToBanks
DenseSet< Operation * > & processedOps
BankAffineStorePattern(MLIRContext *context, uint64_t bankingFactor, DenseMap< Value, SmallVector< Value >> &memoryToBanks, DenseSet< Operation * > &opsToErase, DenseSet< Operation * > &processedOps)
DenseSet< Operation * > & opsToErase
LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp, PatternRewriter &rewriter) const override
DenseMap< Value, SmallVector< Value > > & memoryToBanks
LogicalResult matchAndRewrite(func::ReturnOp returnOp, PatternRewriter &rewriter) const override
BankReturnPattern(MLIRContext *context, DenseMap< Value, SmallVector< Value >> &memoryToBanks)