// CIRCT 20.0.0git -- MemoryBanking.cpp (source listing).
//===- MemoryBanking.cpp - memory bank parallel loops -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements parallel loop memory banking: memories accessed inside
// `affine.parallel` operations are split into `bankingFactor` banks along one
// dimension, and their affine loads/stores are rewritten to dispatch to the
// bank selected by the access index.
//
//===----------------------------------------------------------------------===//

#include "circt/Support/LLVM.h"
#include "circt/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>

namespace circt {
#define GEN_PASS_DEF_MEMORYBANKING
#include "circt/Transforms/Passes.h.inc"
} // namespace circt

using namespace mlir;
using namespace circt;

38namespace {
39
40/// Partition memories used in `affine.parallel` operation by the
41/// `bankingFactor` throughout the program.
42struct MemoryBankingPass
43 : public circt::impl::MemoryBankingBase<MemoryBankingPass> {
44 MemoryBankingPass(const MemoryBankingPass &other) = default;
45 explicit MemoryBankingPass(
46 std::optional<unsigned> bankingFactor = std::nullopt,
47 std::optional<int> bankingDimension = std::nullopt) {}
48
49 void runOnOperation() override;
50
51private:
52 // map from original memory definition to newly allocated banks
53 DenseMap<Value, SmallVector<Value>> memoryToBanks;
54 DenseSet<Operation *> opsToErase;
55 // Track memory references that need to be cleaned up after memory banking is
56 // complete.
57 DenseSet<Value> oldMemRefVals;
58};
59} // namespace
60
61// Collect all memref in the `parOp`'s region'
62DenseSet<Value> collectMemRefs(mlir::affine::AffineParallelOp parOp) {
63 DenseSet<Value> memrefVals;
64 parOp.walk([&](Operation *op) {
65 for (auto operand : op->getOperands()) {
66 if (isa<MemRefType>(operand.getType()))
67 memrefVals.insert(operand);
68 }
69 return WalkResult::advance();
70 });
71 return memrefVals;
72}
73
74// Verify the banking configuration with different conditions.
75void verifyBankingConfigurations(unsigned bankingDimension,
76 unsigned bankingFactor,
77 MemRefType originalType) {
78 ArrayRef<int64_t> originalShape = originalType.getShape();
79 assert(!originalShape.empty() && "memref shape should not be empty");
80 assert(bankingDimension < originalType.getRank() &&
81 "dimension must be within the memref rank");
82 assert(originalShape[bankingDimension] % bankingFactor == 0 &&
83 "memref shape must be evenly divided by the banking factor");
84}
85
86MemRefType computeBankedMemRefType(MemRefType originalType,
87 uint64_t bankingFactor,
88 unsigned bankingDimension) {
89 ArrayRef<int64_t> originalShape = originalType.getShape();
90 SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
91 newShape[bankingDimension] /= bankingFactor;
92 MemRefType newMemRefType =
93 MemRefType::get(newShape, originalType.getElementType(),
94 originalType.getLayout(), originalType.getMemorySpace());
95
96 return newMemRefType;
97}
98
99// Decodes the flat index `linIndex` into an n-dimensional index based on the
100// given `shape` of the array in row-major order. Returns an array to represent
101// the n-dimensional indices.
102SmallVector<int64_t> decodeIndex(int64_t linIndex, ArrayRef<int64_t> shape) {
103 const unsigned rank = shape.size();
104 SmallVector<int64_t> ndIndex(rank, 0);
105
106 // Compute from last dimension to first because we assume row-major.
107 for (int64_t d = rank - 1; d >= 0; --d) {
108 ndIndex[d] = linIndex % shape[d];
109 linIndex /= shape[d];
110 }
111
112 return ndIndex;
113}
114
115// Performs multi-dimensional slicing on `allAttrs` by extracting all elements
116// whose coordinates range from `bankCnt`*`bankingDimension` to
117// (`bankCnt`+1)*`bankingDimension` from `bankingDimension`'s dimension, leaving
118// other dimensions alone.
119SmallVector<SmallVector<Attribute>> sliceSubBlock(ArrayRef<Attribute> allAttrs,
120 ArrayRef<int64_t> memShape,
121 unsigned bankingDimension,
122 unsigned bankingFactor) {
123 size_t numElements = std::reduce(memShape.begin(), memShape.end(), 1,
124 std::multiplies<size_t>());
125 // `bankingFactor` number of flattened attributes that store the information
126 // in the original globalOp.
127 SmallVector<SmallVector<Attribute>> subBlocks;
128 subBlocks.resize(bankingFactor);
129
130 for (unsigned linIndex = 0; linIndex < numElements; ++linIndex) {
131 SmallVector<int64_t> ndIndex = decodeIndex(linIndex, memShape);
132 unsigned subBlockIndex = ndIndex[bankingDimension] % bankingFactor;
133 subBlocks[subBlockIndex].push_back(allAttrs[linIndex]);
134 }
135
136 return subBlocks;
137}
138
139// Handles the splitting of a GetGlobalOp into multiple banked memory and
140// creates new GetGlobalOp to represent each banked memory by slicing the data
141// in the original GetGlobalOp.
142SmallVector<Value, 4> handleGetGlobalOp(memref::GetGlobalOp getGlobalOp,
143 uint64_t bankingFactor,
144 unsigned bankingDimension,
145 MemRefType newMemRefType,
146 OpBuilder &builder) {
147 SmallVector<Value, 4> banks;
148 auto memTy = cast<MemRefType>(getGlobalOp.getType());
149 ArrayRef<int64_t> originalShape = memTy.getShape();
150 auto newShape =
151 SmallVector<int64_t>(originalShape.begin(), originalShape.end());
152 newShape[bankingDimension] = originalShape[bankingDimension] / bankingFactor;
153
154 auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
155 auto globalOpNameAttr = getGlobalOp.getNameAttr();
156 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
157 SymbolTable::lookupSymbolIn(symbolTableOp, globalOpNameAttr));
158 assert(globalOp && "The corresponding GlobalOp should exist in the module");
159 MemRefType globalOpTy = globalOp.getType();
160
161 auto cstAttr =
162 dyn_cast_or_null<DenseElementsAttr>(globalOp.getConstantInitValue());
163 auto attributes = cstAttr.getValues<Attribute>();
164 SmallVector<Attribute, 8> allAttrs(attributes.begin(), attributes.end());
165
166 auto subBlocks =
167 sliceSubBlock(allAttrs, originalShape, bankingDimension, bankingFactor);
168
169 // Initialize globalOp and getGlobalOp's insertion points. Since
170 // bankingFactor is guaranteed to be greater than zero as it would
171 // have early exited if not, the loop below will execute at least
172 // once. So it's safe to manipulate the insertion points here.
173 builder.setInsertionPointAfter(globalOp);
174 OpBuilder::InsertPoint globalOpsInsertPt = builder.saveInsertionPoint();
175 builder.setInsertionPointAfter(getGlobalOp);
176 OpBuilder::InsertPoint getGlobalOpsInsertPt = builder.saveInsertionPoint();
177
178 for (size_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
179 // Prepare relevant information to create a new GlobalOp
180 auto newMemRefTy = MemRefType::get(newShape, globalOpTy.getElementType());
181 auto newTypeAttr = TypeAttr::get(newMemRefTy);
182 std::string newName = llvm::formatv(
183 "{0}_{1}_{2}", globalOpNameAttr.getValue(), "bank", bankCnt);
184 RankedTensorType tensorType =
185 RankedTensorType::get({newShape}, globalOpTy.getElementType());
186 auto newInitValue = DenseElementsAttr::get(tensorType, subBlocks[bankCnt]);
187
188 builder.restoreInsertionPoint(globalOpsInsertPt);
189 auto newGlobalOp = builder.create<memref::GlobalOp>(
190 globalOp.getLoc(), builder.getStringAttr(newName),
191 globalOp.getSymVisibilityAttr(), newTypeAttr, newInitValue,
192 globalOp.getConstantAttr(), globalOp.getAlignmentAttr());
193 builder.setInsertionPointAfter(newGlobalOp);
194 globalOpsInsertPt = builder.saveInsertionPoint();
195
196 builder.restoreInsertionPoint(getGlobalOpsInsertPt);
197 auto newGetGlobalOp = builder.create<memref::GetGlobalOp>(
198 getGlobalOp.getLoc(), newMemRefTy, newGlobalOp.getName());
199 builder.setInsertionPointAfter(newGetGlobalOp);
200 getGlobalOpsInsertPt = builder.saveInsertionPoint();
201
202 banks.push_back(newGetGlobalOp);
203 }
204
205 globalOp.erase();
206 return banks;
207}
208
209unsigned getSpecifiedOrDefaultBankingDim(std::optional<int> bankingDimensionOpt,
210 int64_t rank,
211 ArrayRef<int64_t> shape) {
212 // If the banking dimension is already specified, return it.
213 // Note, the banking dimension will always be nonempty because TableGen will
214 // assign it with a default value -1 if it's not specified by the user. Thus,
215 // -1 is the sentinel value to indicate the default behavior, which is the
216 // innermost dimension with shape greater than 1.
217 if (bankingDimensionOpt.has_value() && *bankingDimensionOpt >= 0) {
218 return static_cast<unsigned>(*bankingDimensionOpt);
219 }
220
221 // Otherwise, find the innermost dimension with size > 1.
222 // For example, [[1], [2], [3], [4]] with `bankingFactor`=2 will be banked to
223 // [[1], [3]] and [[2], [4]].
224 int bankingDimension = -1;
225 for (int dim = rank - 1; dim >= 0; --dim) {
226 if (shape[dim] > 1) {
227 bankingDimension = dim;
228 break;
229 }
230 }
231
232 assert(bankingDimension >= 0 && "No eligible dimension for banking");
233 return static_cast<unsigned>(bankingDimension);
234}
235
236// Retrieve potentially specified banking factor/dimension attributes and
237// overwrite the command line or the default ones.
238void resolveBankingAttributes(Value originalMem, unsigned &bankingFactor,
239 unsigned &bankingDimension) {
240 if (auto *originalDef = originalMem.getDefiningOp()) {
241 if (auto attrFactor = dyn_cast_if_present<IntegerAttr>(
242 originalDef->getAttr("banking.factor")))
243 bankingFactor = attrFactor.getInt();
244 if (auto attrDimension = dyn_cast_if_present<IntegerAttr>(
245 originalDef->getAttr("banking.dimension")))
246 bankingDimension = attrDimension.getInt();
247
248 return;
249 }
250
251 if (isa<BlockArgument>(originalMem)) {
252 auto blockArg = cast<BlockArgument>(originalMem);
253 auto *parentOp = blockArg.getOwner()->getParentOp();
254
255 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
256 assert(funcOp &&
257 "Expected the original memory to be a FuncOp block argument!");
258
259 unsigned argIndex = blockArg.getArgNumber();
260 if (auto argAttrs = funcOp.getArgAttrDict(argIndex)) {
261 if (auto attrFactor =
262 dyn_cast_if_present<IntegerAttr>(argAttrs.get("banking.factor")))
263 bankingFactor = attrFactor.getInt();
264 if (auto attrDimension = dyn_cast_if_present<IntegerAttr>(
265 argAttrs.get("banking.dimension")))
266 bankingDimension = attrDimension.getInt();
267 }
268
269 return;
270 }
271}
272
273// Update the argument types of `funcOp` by inserting `numInsertedArgs` number
274// of `newMemRefType` after `argIndex`.
275void updateFuncOpArgumentTypes(func::FuncOp funcOp, unsigned argIndex,
276 MemRefType newMemRefType,
277 unsigned numInsertedArgs) {
278 auto originalArgTypes = funcOp.getFunctionType().getInputs();
279 SmallVector<Type, 4> updatedArgTypes;
280
281 // Rebuild the argument types, inserting new types for the newly added
282 // arguments
283 for (unsigned i = 0; i < originalArgTypes.size(); ++i) {
284 updatedArgTypes.push_back(originalArgTypes[i]);
285
286 // Insert new argument types after the specified argument index
287 if (i == argIndex) {
288 for (unsigned j = 0; j < numInsertedArgs; ++j) {
289 updatedArgTypes.push_back(newMemRefType);
290 }
291 }
292 }
293
294 // Update the function type with the new argument types
295 auto resultTypes = funcOp.getFunctionType().getResults();
296 auto newFuncType =
297 FunctionType::get(funcOp.getContext(), updatedArgTypes, resultTypes);
298 funcOp.setType(newFuncType);
299}
300
301// Update `funcOp`'s "arg_attrs" by inserting `numInsertedArgs` number of empty
302// DictionaryAttr after `argIndex`.
303void updateFuncOpArgAttrs(func::FuncOp funcOp, unsigned argIndex,
304 unsigned numInsertedArgs) {
305 ArrayAttr existingArgAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs");
306 SmallVector<Attribute, 4> updatedArgAttrs;
307 unsigned numArguments = funcOp.getNumArguments();
308 unsigned newNumArguments = numArguments + numInsertedArgs;
309 updatedArgAttrs.resize(newNumArguments);
310
311 // Copy existing attributes, adjusting for the new arguments
312 for (unsigned i = 0; i < numArguments; ++i) {
313 // Shift attributes for arguments after the inserted ones.
314 unsigned newIndex = (i > argIndex) ? i + numInsertedArgs : i;
315 updatedArgAttrs[newIndex] = existingArgAttrs
316 ? existingArgAttrs[i]
317 : DictionaryAttr::get(funcOp.getContext());
318 }
319
320 // Initialize new attributes for the inserted arguments as empty dictionaries
321 for (unsigned i = 0; i < numInsertedArgs; ++i) {
322 updatedArgAttrs[argIndex + 1 + i] =
323 DictionaryAttr::get(funcOp.getContext());
324 }
325
326 // Set the updated attributes.
327 funcOp->setAttr("arg_attrs",
328 ArrayAttr::get(funcOp.getContext(), updatedArgAttrs));
329}
330
331SmallVector<Value, 4> createBanks(Value originalMem, unsigned bankingFactor,
332 std::optional<int> bankingDimensionOpt) {
333 MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
334 unsigned rank = originalMemRefType.getRank();
335 ArrayRef<int64_t> shape = originalMemRefType.getShape();
336
337 unsigned bankingDimension =
338 getSpecifiedOrDefaultBankingDim(bankingDimensionOpt, rank, shape);
339
340 resolveBankingAttributes(originalMem, bankingFactor, bankingDimension);
341
342 verifyBankingConfigurations(bankingDimension, bankingFactor,
343 originalMemRefType);
344
345 MemRefType newMemRefType = computeBankedMemRefType(
346 originalMemRefType, bankingFactor, bankingDimension);
347 SmallVector<Value, 4> banks;
348 if (auto blockArgMem = dyn_cast<BlockArgument>(originalMem)) {
349 Block *block = blockArgMem.getOwner();
350 unsigned blockArgNum = blockArgMem.getArgNumber();
351
352 for (unsigned i = 0; i < bankingFactor; ++i)
353 block->insertArgument(blockArgNum + 1 + i, newMemRefType,
354 blockArgMem.getLoc());
355
356 auto blockArgs =
357 block->getArguments().slice(blockArgNum + 1, bankingFactor);
358 banks.append(blockArgs.begin(), blockArgs.end());
359
360 auto *parentOp = block->getParentOp();
361 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
362 assert(funcOp && "BlockArgument is not part of a FuncOp");
363 // Update the ArgumentTypes of `funcOp` so that we can correctly get
364 // `getArgAttrDict` when resolving banking attributes across the iterations
365 // of creating new banks.
366 updateFuncOpArgumentTypes(funcOp, blockArgNum, newMemRefType,
367 bankingFactor);
368 updateFuncOpArgAttrs(funcOp, blockArgNum, bankingFactor);
369 } else {
370 Operation *originalDef = originalMem.getDefiningOp();
371 Location loc = originalDef->getLoc();
372 OpBuilder builder(originalDef);
373 builder.setInsertionPointAfter(originalDef);
374 TypeSwitch<Operation *>(originalDef)
375 .Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
376 for (uint64_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
377 auto bankAllocOp =
378 builder.create<memref::AllocOp>(loc, newMemRefType);
379 banks.push_back(bankAllocOp);
380 }
381 })
382 .Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
383 for (uint64_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
384 auto bankAllocaOp =
385 builder.create<memref::AllocaOp>(loc, newMemRefType);
386 banks.push_back(bankAllocaOp);
387 }
388 })
389 .Case<memref::GetGlobalOp>([&](memref::GetGlobalOp getGlobalOp) {
390 auto newBanks =
391 handleGetGlobalOp(getGlobalOp, bankingFactor, bankingDimension,
392 newMemRefType, builder);
393 banks.append(newBanks.begin(), newBanks.end());
394 })
395 .Default([](Operation *) {
396 llvm_unreachable("Unhandled memory operation type");
397 });
398 }
399 return banks;
400}
401
402// Replace the original load operations with newly created memory banks
404 : public OpRewritePattern<mlir::affine::AffineLoadOp> {
405 BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor,
406 std::optional<int> bankingDimensionOpt,
407 DenseMap<Value, SmallVector<Value>> &memoryToBanks,
408 DenseSet<Value> &oldMemRefVals)
409 : OpRewritePattern<mlir::affine::AffineLoadOp>(context),
412
413 LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp,
414 PatternRewriter &rewriter) const override {
415 Location loc = loadOp.getLoc();
416 auto originalMem = loadOp.getMemref();
417 auto banks = memoryToBanks[originalMem];
418 auto loadIndices = loadOp.getIndices();
419 MemRefType originalMemRefType = loadOp.getMemRefType();
420 int64_t memrefRank = originalMemRefType.getRank();
421 ArrayRef<int64_t> shape = originalMemRefType.getShape();
422
423 auto bankingDimension =
425
426 resolveBankingAttributes(originalMem, bankingFactor, bankingDimension);
427
429 originalMemRefType);
430
431 auto modMap = AffineMap::get(
432 /*dimCount=*/memrefRank, /*symbolCount=*/0,
433 {rewriter.getAffineDimExpr(bankingDimension) % bankingFactor});
434 auto divMap = AffineMap::get(
435 memrefRank, 0,
436 {rewriter.getAffineDimExpr(bankingDimension).floorDiv(bankingFactor)});
437
438 Value bankIndex =
439 rewriter.create<affine::AffineApplyOp>(loc, modMap, loadIndices);
440 Value offset =
441 rewriter.create<affine::AffineApplyOp>(loc, divMap, loadIndices);
442 SmallVector<Value, 4> newIndices(loadIndices.begin(), loadIndices.end());
443 newIndices[bankingDimension] = offset;
444
445 SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
446
447 SmallVector<int64_t, 4> caseValues;
448 for (unsigned i = 0; i < bankingFactor; ++i)
449 caseValues.push_back(i);
450
451 rewriter.setInsertionPoint(loadOp);
452 scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
453 loc, resultTypes, bankIndex, caseValues,
454 /*numRegions=*/bankingFactor);
455
456 for (unsigned i = 0; i < bankingFactor; ++i) {
457 Region &caseRegion = switchOp.getCaseRegions()[i];
458 rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
459 Value bankedLoad = rewriter.create<mlir::affine::AffineLoadOp>(
460 loc, banks[i], newIndices);
461 rewriter.create<scf::YieldOp>(loc, bankedLoad);
462 }
463
464 Region &defaultRegion = switchOp.getDefaultRegion();
465 assert(defaultRegion.empty() && "Default region should be empty");
466 rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
467
468 TypedAttr zeroAttr =
469 cast<TypedAttr>(rewriter.getZeroAttr(loadOp.getType()));
470 auto defaultValue = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
471 rewriter.create<scf::YieldOp>(loc, defaultValue.getResult());
472
473 // We track Load's memory reference only if it is a block argument - this is
474 // the only case where the reference isn't replaced.
475 if (Value memRef = loadOp.getMemref(); isa<BlockArgument>(memRef))
476 oldMemRefVals.insert(memRef);
477 rewriter.replaceOp(loadOp, switchOp.getResult(0));
478
479 return success();
480 }
481
482private:
483 mutable unsigned bankingFactor;
484 mutable std::optional<int> bankingDimensionOpt;
485 DenseMap<Value, SmallVector<Value>> &memoryToBanks;
486 DenseSet<Value> &oldMemRefVals;
487};
488
489// Replace the original store operations with newly created memory banks
491 : public OpRewritePattern<mlir::affine::AffineStoreOp> {
492 BankAffineStorePattern(MLIRContext *context, uint64_t bankingFactor,
493 std::optional<int> bankingDimensionOpt,
494 DenseMap<Value, SmallVector<Value>> &memoryToBanks,
495 DenseSet<Operation *> &opsToErase,
496 DenseSet<Operation *> &processedOps,
497 DenseSet<Value> &oldMemRefVals)
498 : OpRewritePattern<mlir::affine::AffineStoreOp>(context),
502
503 LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp,
504 PatternRewriter &rewriter) const override {
505 if (processedOps.contains(storeOp)) {
506 return failure();
507 }
508 Location loc = storeOp.getLoc();
509 auto originalMem = storeOp.getMemref();
510 auto banks = memoryToBanks[originalMem];
511 auto storeIndices = storeOp.getIndices();
512 auto originalMemRefType = storeOp.getMemRefType();
513 int64_t memrefRank = originalMemRefType.getRank();
514 ArrayRef<int64_t> shape = originalMemRefType.getShape();
515
516 auto bankingDimension =
518
519 resolveBankingAttributes(originalMem, bankingFactor, bankingDimension);
520
522 originalMemRefType);
523
524 auto modMap = AffineMap::get(
525 /*dimCount=*/memrefRank, /*symbolCount=*/0,
526 {rewriter.getAffineDimExpr(bankingDimension) % bankingFactor});
527 auto divMap = AffineMap::get(
528 memrefRank, 0,
529 {rewriter.getAffineDimExpr(bankingDimension).floorDiv(bankingFactor)});
530
531 Value bankIndex =
532 rewriter.create<affine::AffineApplyOp>(loc, modMap, storeIndices);
533 Value offset =
534 rewriter.create<affine::AffineApplyOp>(loc, divMap, storeIndices);
535 SmallVector<Value, 4> newIndices(storeIndices.begin(), storeIndices.end());
536 newIndices[bankingDimension] = offset;
537
538 SmallVector<Type> resultTypes = {};
539
540 SmallVector<int64_t, 4> caseValues;
541 for (unsigned i = 0; i < bankingFactor; ++i)
542 caseValues.push_back(i);
543
544 rewriter.setInsertionPoint(storeOp);
545 scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
546 loc, resultTypes, bankIndex, caseValues,
547 /*numRegions=*/bankingFactor);
548
549 for (unsigned i = 0; i < bankingFactor; ++i) {
550 Region &caseRegion = switchOp.getCaseRegions()[i];
551 rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
552 rewriter.create<mlir::affine::AffineStoreOp>(
553 loc, storeOp.getValueToStore(), banks[i], newIndices);
554 rewriter.create<scf::YieldOp>(loc);
555 }
556
557 Region &defaultRegion = switchOp.getDefaultRegion();
558 assert(defaultRegion.empty() && "Default region should be empty");
559 rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
560
561 rewriter.create<scf::YieldOp>(loc);
562
563 processedOps.insert(storeOp);
564 opsToErase.insert(storeOp);
565 oldMemRefVals.insert(storeOp.getMemref());
566
567 return success();
568 }
569
570private:
571 mutable unsigned bankingFactor;
572 mutable std::optional<int> bankingDimensionOpt;
573 DenseMap<Value, SmallVector<Value>> &memoryToBanks;
574 DenseSet<Operation *> &opsToErase;
575 DenseSet<Operation *> &processedOps;
576 DenseSet<Value> &oldMemRefVals;
577};
578
579// Replace the original return operation with newly created memory banks
580struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
581 BankReturnPattern(MLIRContext *context,
582 DenseMap<Value, SmallVector<Value>> &memoryToBanks)
583 : OpRewritePattern<func::ReturnOp>(context),
585
586 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
587 PatternRewriter &rewriter) const override {
588 Location loc = returnOp.getLoc();
589 SmallVector<Value, 4> newReturnOperands;
590 bool allOrigMemsUsedByReturn = true;
591 for (auto operand : returnOp.getOperands()) {
592 if (!memoryToBanks.contains(operand)) {
593 newReturnOperands.push_back(operand);
594 continue;
595 }
596 if (operand.hasOneUse())
597 allOrigMemsUsedByReturn = false;
598 auto banks = memoryToBanks[operand];
599 newReturnOperands.append(banks.begin(), banks.end());
600 }
601
602 func::FuncOp funcOp = returnOp.getParentOp();
603 rewriter.setInsertionPointToEnd(&funcOp.getBlocks().front());
604 auto newReturnOp =
605 rewriter.create<func::ReturnOp>(loc, ValueRange(newReturnOperands));
606 TypeRange newReturnType = TypeRange(newReturnOperands);
607 FunctionType newFuncType =
608 FunctionType::get(funcOp.getContext(),
609 funcOp.getFunctionType().getInputs(), newReturnType);
610 funcOp.setType(newFuncType);
611
612 if (allOrigMemsUsedByReturn)
613 rewriter.replaceOp(returnOp, newReturnOp);
614
615 return success();
616 }
617
618private:
619 DenseMap<Value, SmallVector<Value>> &memoryToBanks;
620};
621
622// Clean up the empty uses old memory values by either erasing the defining
623// operation or replace the block arguments with new ones that corresponds to
624// the newly created banks. Change the function signature if the old memory
625// values are used as function arguments and/or return values.
626LogicalResult cleanUpOldMemRefs(DenseSet<Value> &oldMemRefVals,
627 DenseSet<Operation *> &opsToErase) {
628 DenseSet<func::FuncOp> funcsToModify;
629 SmallVector<Value, 4> valuesToErase;
630 DenseMap<func::FuncOp, SmallVector<unsigned, 4>> erasedArgIndices;
631 for (auto &memrefVal : oldMemRefVals) {
632 valuesToErase.push_back(memrefVal);
633 if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
634 if (auto funcOp =
635 dyn_cast<func::FuncOp>(blockArg.getOwner()->getParentOp())) {
636 funcsToModify.insert(funcOp);
637 erasedArgIndices[funcOp].push_back(blockArg.getArgNumber());
638 }
639 }
640 }
641
642 for (auto *op : opsToErase) {
643 op->erase();
644 }
645 // Erase values safely.
646 for (auto &memrefVal : valuesToErase) {
647 assert(memrefVal.use_empty() && "use must be empty");
648 if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
649 blockArg.getOwner()->eraseArgument(blockArg.getArgNumber());
650 } else if (auto *op = memrefVal.getDefiningOp()) {
651 op->erase();
652 }
653 }
654
655 // Modify the function argument attributes and function type accordingly
656 for (auto funcOp : funcsToModify) {
657 ArrayAttr existingArgAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs");
658 if (existingArgAttrs) {
659 SmallVector<Attribute, 4> updatedArgAttrs;
660 auto erasedIndices = erasedArgIndices[funcOp];
661 DenseSet<unsigned> indicesToErase(erasedIndices.begin(),
662 erasedIndices.end());
663 for (unsigned i = 0; i < existingArgAttrs.size(); ++i) {
664 if (!indicesToErase.contains(i))
665 updatedArgAttrs.push_back(existingArgAttrs[i]);
666 }
667
668 funcOp->setAttr("arg_attrs",
669 ArrayAttr::get(funcOp.getContext(), updatedArgAttrs));
670 }
671
672 SmallVector<Type, 4> newArgTypes;
673 for (BlockArgument arg : funcOp.getArguments()) {
674 newArgTypes.push_back(arg.getType());
675 }
676 FunctionType newFuncType =
677 FunctionType::get(funcOp.getContext(), newArgTypes,
678 funcOp.getFunctionType().getResults());
679 funcOp.setType(newFuncType);
680 }
681
682 return success();
683}
684
685void MemoryBankingPass::runOnOperation() {
686 if (getOperation().isExternal() || bankingFactor == 1)
687 return;
688
689 if (bankingFactor == 0) {
690 getOperation().emitError("banking factor must be greater than 1");
691 signalPassFailure();
692 return;
693 }
694
695 getOperation().walk([&](mlir::affine::AffineParallelOp parOp) {
696 DenseSet<Value> memrefsInPar = collectMemRefs(parOp);
697
698 for (auto memrefVal : memrefsInPar) {
699 auto [it, inserted] =
700 memoryToBanks.insert(std::make_pair(memrefVal, SmallVector<Value>{}));
701 if (inserted)
702 it->second = createBanks(memrefVal, bankingFactor, bankingDimension);
703 }
704 });
705
706 auto *ctx = &getContext();
707 RewritePatternSet patterns(ctx);
708
709 DenseSet<Operation *> processedOps;
710 patterns.add<BankAffineLoadPattern>(ctx, bankingFactor, bankingDimension,
711 memoryToBanks, oldMemRefVals);
712 patterns.add<BankAffineStorePattern>(ctx, bankingFactor, bankingDimension,
713 memoryToBanks, opsToErase, processedOps,
714 oldMemRefVals);
715 patterns.add<BankReturnPattern>(ctx, memoryToBanks);
716
717 GreedyRewriteConfig config;
718 config.strictMode = GreedyRewriteStrictness::ExistingOps;
719 if (failed(
720 applyPatternsGreedily(getOperation(), std::move(patterns), config))) {
721 signalPassFailure();
722 }
723
724 // Clean up the old memref values
725 if (failed(cleanUpOldMemRefs(oldMemRefVals, opsToErase))) {
726 signalPassFailure();
727 }
728}
729
730namespace circt {
731std::unique_ptr<mlir::Pass>
732createMemoryBankingPass(std::optional<unsigned> bankingFactor,
733 std::optional<int> bankingDimension) {
734 return std::make_unique<MemoryBankingPass>(bankingFactor, bankingDimension);
735}
736} // namespace circt
// Doxygen cross-reference tooltips captured together with this source
// listing. They are not part of this translation unit and are kept only as
// comments:
//   assert(baseType &&"element must be base type")
//   MlirType uint64_t numElements
//   Definition CHIRRTL.cpp:30
//   SmallVector< Value, 4 > handleGetGlobalOp(memref::GetGlobalOp getGlobalOp, uint64_t bankingFactor, unsigned bankingDimension, MemRefType newMemRefType, OpBuilder &builder)
//   DenseSet< Value > collectMemRefs(mlir::affine::AffineParallelOp parOp)
//   void verifyBankingConfigurations(unsigned bankingDimension, unsigned bankingFactor, MemRefType originalType)
//   SmallVector< Value, 4 > createBanks(Value originalMem, unsigned bankingFactor, std::optional< int > bankingDimensionOpt)
//   void updateFuncOpArgAttrs(func::FuncOp funcOp, unsigned argIndex, unsigned numInsertedArgs)
//   SmallVector< int64_t > decodeIndex(int64_t linIndex, ArrayRef< int64_t > shape)
//   LogicalResult cleanUpOldMemRefs(DenseSet< Value > &oldMemRefVals, DenseSet< Operation * > &opsToErase)
//   void updateFuncOpArgumentTypes(func::FuncOp funcOp, unsigned argIndex, MemRefType newMemRefType, unsigned numInsertedArgs)
//   unsigned getSpecifiedOrDefaultBankingDim(std::optional< int > bankingDimensionOpt, int64_t rank, ArrayRef< int64_t > shape)
//   MemRefType computeBankedMemRefType(MemRefType originalType, uint64_t bankingFactor, unsigned bankingDimension)
//   SmallVector< SmallVector< Attribute > > sliceSubBlock(ArrayRef< Attribute > allAttrs, ArrayRef< int64_t > memShape, unsigned bankingDimension, unsigned bankingFactor)
//   void resolveBankingAttributes(Value originalMem, unsigned &bankingFactor, unsigned &bankingDimension)
//   The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
//   std::unique_ptr< mlir::Pass > createMemoryBankingPass(std::optional< unsigned > bankingFactor=std::nullopt, std::optional< int > bankingDimension=std::nullopt)
//   DenseMap< Value, SmallVector< Value > > & memoryToBanks
//   std::optional< int > bankingDimensionOpt
//   LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp, PatternRewriter &rewriter) const override
//   BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor, std::optional< int > bankingDimensionOpt, DenseMap< Value, SmallVector< Value > > &memoryToBanks, DenseSet< Value > &oldMemRefVals)
//   DenseSet< Value > & oldMemRefVals
//   BankAffineStorePattern(MLIRContext *context, uint64_t bankingFactor, std::optional< int > bankingDimensionOpt, DenseMap< Value, SmallVector< Value > > &memoryToBanks, DenseSet< Operation * > &opsToErase, DenseSet< Operation * > &processedOps, DenseSet< Value > &oldMemRefVals)
//   DenseMap< Value, SmallVector< Value > > & memoryToBanks
//   DenseSet< Value > & oldMemRefVals
//   DenseSet< Operation * > & processedOps
//   DenseSet< Operation * > & opsToErase
//   LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp, PatternRewriter &rewriter) const override
//   std::optional< int > bankingDimensionOpt
//   DenseMap< Value, SmallVector< Value > > & memoryToBanks
//   BankReturnPattern(MLIRContext *context, DenseMap< Value, SmallVector< Value > > &memoryToBanks)
//   LogicalResult matchAndRewrite(func::ReturnOp returnOp, PatternRewriter &rewriter) const override