Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
MemoryBanking.cpp
Go to the documentation of this file.
1//===- MemoryBanking.cpp - memory bank parallel loops -----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements parallel loop memory banking.
10//
11//===----------------------------------------------------------------------===//
12
13#include "circt/Support/LLVM.h"
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/SCF/IR/SCF.h"
19#include "mlir/IR/AffineExpr.h"
20#include "mlir/IR/AffineMap.h"
21#include "mlir/IR/OperationSupport.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Support/LLVM.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
28#include <numeric>
29
30namespace circt {
31#define GEN_PASS_DEF_MEMORYBANKING
32#include "circt/Transforms/Passes.h.inc"
33} // namespace circt
34
35using namespace mlir;
36using namespace circt;
37
38namespace {
39struct BankingConfigAttributes {
40 Attribute factors;
41 Attribute dimensions;
42};
43
44constexpr std::string_view bankingFactorsStr = "banking.factors";
45constexpr std::string_view bankingDimensionsStr = "banking.dimensions";
46
47/// Partition memories used in `affine.parallel` operation by the
48/// `bankingFactor` throughout the program.
49struct MemoryBankingPass
50 : public circt::impl::MemoryBankingBase<MemoryBankingPass> {
51 MemoryBankingPass(const MemoryBankingPass &other) = default;
52 explicit MemoryBankingPass(ArrayRef<unsigned> bankingFactors = {},
53 ArrayRef<unsigned> bankingDimensions = {}) {}
54
55 void runOnOperation() override;
56
57 LogicalResult applyMemoryBanking(Operation *, MLIRContext *);
58
59 SmallVector<Value, 4> createBanks(OpBuilder &builder, Value originalMem);
60
61 void setAllBankingAttributes(Operation *, MLIRContext *);
62
63private:
64 SmallVector<unsigned, 4> bankingFactors;
65 SmallVector<unsigned, 4> bankingDimensions;
66 // map from original memory definition to newly allocated banks
67 DenseMap<Value, SmallVector<Value>> memoryToBanks;
68 DenseSet<Operation *> opsToErase;
69 // Track memory references that need to be cleaned up after memory banking is
70 // complete.
71 DenseSet<Value> oldMemRefVals;
72};
73} // namespace
74
75BankingConfigAttributes getMemRefBankingConfig(Value originalMem) {
76 Attribute bankingFactorsAttr, bankingDimensionsAttr;
77 if (auto blockArg = dyn_cast<BlockArgument>(originalMem)) {
78 Block *block = blockArg.getOwner();
79
80 auto *parentOp = block->getParentOp();
81 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
82 assert(funcOp &&
83 "Expected the original memory to be a FuncOp block argument!");
84 unsigned argIndex = blockArg.getArgNumber();
85 if (auto argAttrs = funcOp.getArgAttrDict(argIndex)) {
86 bankingFactorsAttr = argAttrs.get(bankingFactorsStr);
87 bankingDimensionsAttr = argAttrs.get(bankingDimensionsStr);
88 }
89 } else {
90 Operation *originalDef = originalMem.getDefiningOp();
91 bankingFactorsAttr = originalDef->getAttr(bankingFactorsStr);
92 bankingDimensionsAttr = originalDef->getAttr(bankingDimensionsStr);
93 }
94 return BankingConfigAttributes{bankingFactorsAttr, bankingDimensionsAttr};
95}
96
97// Collect all memref in the `parOp`'s region'
98DenseSet<Value> collectMemRefs(affine::AffineParallelOp affineParallelOp) {
99 DenseSet<Value> memrefVals;
100 affineParallelOp.walk([&](Operation *op) {
101 if (!isa<affine::AffineWriteOpInterface>(op) &&
102 !isa<affine::AffineReadOpInterface>(op))
103 return WalkResult::advance();
104
105 auto read = dyn_cast<affine::AffineReadOpInterface>(op);
106 Value memref = read ? read.getMemRef()
107 : cast<affine::AffineWriteOpInterface>(op).getMemRef();
108 memrefVals.insert(memref);
109 return WalkResult::advance();
110 });
111 return memrefVals;
112}
113
114// Verify the banking configuration with different conditions.
115void verifyBankingConfigurations(unsigned bankingFactor,
116 unsigned bankingDimension,
117 MemRefType originalType) {
118 ArrayRef<int64_t> originalShape = originalType.getShape();
119 assert(!originalShape.empty() && "memref shape should not be empty");
120 assert(bankingDimension < originalType.getRank() &&
121 "dimension must be within the memref rank");
122 assert(originalShape[bankingDimension] % bankingFactor == 0 &&
123 "memref shape must be evenly divided by the banking factor");
124}
125
126MemRefType computeBankedMemRefType(MemRefType originalType,
127 uint64_t bankingFactor,
128 unsigned bankingDimension) {
129 ArrayRef<int64_t> originalShape = originalType.getShape();
130 SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
131 newShape[bankingDimension] /= bankingFactor;
132 MemRefType newMemRefType =
133 MemRefType::get(newShape, originalType.getElementType(),
134 originalType.getLayout(), originalType.getMemorySpace());
135
136 return newMemRefType;
137}
138
139// Decodes the flat index `linIndex` into an n-dimensional index based on the
140// given `shape` of the array in row-major order. Returns an array to represent
141// the n-dimensional indices.
142SmallVector<int64_t> decodeIndex(int64_t linIndex, ArrayRef<int64_t> shape) {
143 const unsigned rank = shape.size();
144 SmallVector<int64_t> ndIndex(rank, 0);
145
146 // Compute from last dimension to first because we assume row-major.
147 for (int64_t d = rank - 1; d >= 0; --d) {
148 ndIndex[d] = linIndex % shape[d];
149 linIndex /= shape[d];
150 }
151
152 return ndIndex;
153}
154
155// Performs multi-dimensional slicing on `allAttrs` by extracting all elements
156// whose coordinates range from `bankCnt`*`bankingDimension` to
157// (`bankCnt`+1)*`bankingDimension` from `bankingDimension`'s dimension, leaving
158// other dimensions alone.
159SmallVector<SmallVector<Attribute>> sliceSubBlock(ArrayRef<Attribute> allAttrs,
160 ArrayRef<int64_t> memShape,
161 unsigned bankingDimension,
162 unsigned bankingFactor) {
163 size_t numElements = std::reduce(memShape.begin(), memShape.end(), 1,
164 std::multiplies<size_t>());
165 // `bankingFactor` number of flattened attributes that store the information
166 // in the original globalOp.
167 SmallVector<SmallVector<Attribute>> subBlocks;
168 subBlocks.resize(bankingFactor);
169
170 for (unsigned linIndex = 0; linIndex < numElements; ++linIndex) {
171 SmallVector<int64_t> ndIndex = decodeIndex(linIndex, memShape);
172 unsigned subBlockIndex = ndIndex[bankingDimension] % bankingFactor;
173 subBlocks[subBlockIndex].push_back(allAttrs[linIndex]);
174 }
175
176 return subBlocks;
177}
178
179// Handles the splitting of a GetGlobalOp into multiple banked memory and
180// creates new GetGlobalOp to represent each banked memory by slicing the data
181// in the original GetGlobalOp.
182SmallVector<Value, 4>
183handleGetGlobalOp(memref::GetGlobalOp getGlobalOp, uint64_t bankingFactor,
184 unsigned bankingDimension, MemRefType newMemRefType,
185 OpBuilder &builder, DictionaryAttr remainingAttrs) {
186 SmallVector<Value, 4> banks;
187 auto memTy = cast<MemRefType>(getGlobalOp.getType());
188 ArrayRef<int64_t> originalShape = memTy.getShape();
189 auto newShape =
190 SmallVector<int64_t>(originalShape.begin(), originalShape.end());
191 newShape[bankingDimension] = originalShape[bankingDimension] / bankingFactor;
192
193 auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
194 auto globalOpNameAttr = getGlobalOp.getNameAttr();
195 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
196 SymbolTable::lookupSymbolIn(symbolTableOp, globalOpNameAttr));
197 assert(globalOp && "The corresponding GlobalOp should exist in the module");
198 MemRefType globalOpTy = globalOp.getType();
199
200 auto cstAttr =
201 dyn_cast_or_null<DenseElementsAttr>(globalOp.getConstantInitValue());
202 auto attributes = cstAttr.getValues<Attribute>();
203 SmallVector<Attribute, 8> allAttrs(attributes.begin(), attributes.end());
204
205 auto subBlocks =
206 sliceSubBlock(allAttrs, originalShape, bankingDimension, bankingFactor);
207
208 // Initialize globalOp and getGlobalOp's insertion points. Since
209 // bankingFactor is guaranteed to be greater than zero as it would
210 // have early exited if not, the loop below will execute at least
211 // once. So it's safe to manipulate the insertion points here.
212 builder.setInsertionPointAfter(globalOp);
213 OpBuilder::InsertPoint globalOpsInsertPt = builder.saveInsertionPoint();
214 builder.setInsertionPointAfter(getGlobalOp);
215 OpBuilder::InsertPoint getGlobalOpsInsertPt = builder.saveInsertionPoint();
216
217 for (size_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
218 // Prepare relevant information to create a new GlobalOp
219 auto newMemRefTy = MemRefType::get(newShape, globalOpTy.getElementType());
220 auto newTypeAttr = TypeAttr::get(newMemRefTy);
221 std::string newName = llvm::formatv(
222 "{0}_{1}_{2}", globalOpNameAttr.getValue(), "bank", bankCnt);
223 RankedTensorType tensorType =
224 RankedTensorType::get({newShape}, globalOpTy.getElementType());
225 auto newInitValue = DenseElementsAttr::get(tensorType, subBlocks[bankCnt]);
226
227 builder.restoreInsertionPoint(globalOpsInsertPt);
228 auto newGlobalOp = builder.create<memref::GlobalOp>(
229 globalOp.getLoc(), builder.getStringAttr(newName),
230 globalOp.getSymVisibilityAttr(), newTypeAttr, newInitValue,
231 globalOp.getConstantAttr(), globalOp.getAlignmentAttr());
232 builder.setInsertionPointAfter(newGlobalOp);
233 globalOpsInsertPt = builder.saveInsertionPoint();
234
235 builder.restoreInsertionPoint(getGlobalOpsInsertPt);
236 auto newGetGlobalOp = builder.create<memref::GetGlobalOp>(
237 getGlobalOp.getLoc(), newMemRefTy, newGlobalOp.getName());
238 newGetGlobalOp->setAttrs(remainingAttrs);
239 builder.setInsertionPointAfter(newGetGlobalOp);
240 getGlobalOpsInsertPt = builder.saveInsertionPoint();
241
242 banks.push_back(newGetGlobalOp);
243 }
244
245 globalOp.erase();
246 return banks;
247}
248
249SmallVector<unsigned, 4>
250getSpecifiedOrDefaultBankingDim(const ArrayRef<unsigned> bankingDimensions,
251 int64_t rank, ArrayRef<int64_t> shape) {
252 // If the banking dimension is already specified, return it.
253 if (!bankingDimensions.empty()) {
254 return SmallVector<unsigned, 4>(bankingDimensions.begin(),
255 bankingDimensions.end());
256 }
257
258 // Otherwise, find the innermost dimension with size > 1.
259 // For example, [[1], [2], [3], [4]] with `bankingFactor`=2 will be banked to
260 // [[1], [3]] and [[2], [4]].
261 int bankingDimension = -1;
262 for (int dim = rank - 1; dim >= 0; --dim) {
263 if (shape[dim] > 1) {
264 bankingDimension = dim;
265 break;
266 }
267 }
268
269 assert(bankingDimension >= 0 && "No eligible dimension for banking");
270 return SmallVector<unsigned, 4>{static_cast<unsigned>(bankingDimension)};
271}
272
273// Update the argument types of `funcOp` by inserting `numInsertedArgs` number
274// of `newMemRefType` after `argIndex`.
275void updateFuncOpArgumentTypes(func::FuncOp funcOp, unsigned argIndex,
276 MemRefType newMemRefType,
277 unsigned numInsertedArgs) {
278 auto originalArgTypes = funcOp.getFunctionType().getInputs();
279 SmallVector<Type, 4> updatedArgTypes;
280
281 // Rebuild the argument types, inserting new types for the newly added
282 // arguments
283 for (unsigned i = 0; i < originalArgTypes.size(); ++i) {
284 updatedArgTypes.push_back(originalArgTypes[i]);
285
286 // Insert new argument types after the specified argument index
287 if (i == argIndex) {
288 for (unsigned j = 0; j < numInsertedArgs; ++j) {
289 updatedArgTypes.push_back(newMemRefType);
290 }
291 }
292 }
293
294 // Update the function type with the new argument types
295 auto resultTypes = funcOp.getFunctionType().getResults();
296 auto newFuncType =
297 FunctionType::get(funcOp.getContext(), updatedArgTypes, resultTypes);
298 funcOp.setType(newFuncType);
299}
300
301// Update `funcOp`'s "arg_attrs" by inserting `numInsertedArgs` number of
302// `remainingAttrs` after `argIndex`.
303void updateFuncOpArgAttrs(func::FuncOp funcOp, unsigned argIndex,
304 unsigned numInsertedArgs,
305 DictionaryAttr remainingAttrs) {
306 ArrayAttr existingArgAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs");
307 SmallVector<Attribute, 4> updatedArgAttrs;
308 unsigned numArguments = funcOp.getNumArguments();
309 unsigned newNumArguments = numArguments + numInsertedArgs;
310 updatedArgAttrs.resize(newNumArguments);
311
312 // Copy existing attributes, adjusting for the new arguments
313 for (unsigned i = 0; i < numArguments; ++i) {
314 // Shift attributes for arguments after the inserted ones.
315 unsigned newIndex = (i > argIndex) ? i + numInsertedArgs : i;
316 updatedArgAttrs[newIndex] = existingArgAttrs
317 ? existingArgAttrs[i]
318 : DictionaryAttr::get(funcOp.getContext());
319 }
320
321 // Initialize new attributes for the inserted arguments as empty dictionaries
322 for (unsigned i = 0; i < numInsertedArgs; ++i) {
323 updatedArgAttrs[argIndex + 1 + i] = remainingAttrs;
324 }
325
326 // Set the updated attributes.
327 funcOp->setAttr("arg_attrs",
328 ArrayAttr::get(funcOp.getContext(), updatedArgAttrs));
329}
330
331unsigned getCurrBankingInfo(BankingConfigAttributes bankingConfigAttrs,
332 StringRef attrName) {
333 auto getFirstInteger = [](Attribute attr) -> unsigned {
334 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
335 assert(!arrayAttr.empty() &&
336 "BankingConfig ArrayAttr should not be empty");
337 auto intAttr = dyn_cast<IntegerAttr>(arrayAttr.getValue().front());
338 assert(intAttr && "BankingConfig elements must be integers");
339 return intAttr.getInt();
340 }
341 auto intAttr = dyn_cast<IntegerAttr>(attr);
342 assert(intAttr && "BankingConfig attribute must be an integer");
343 return intAttr.getInt();
344 };
345
346 if (attrName.str() == bankingFactorsStr) {
347 return getFirstInteger(bankingConfigAttrs.factors);
348 }
349
350 assert(attrName.str() == bankingDimensionsStr &&
351 "BankingConfig only contains 'factors' and 'dimensions' attributes");
352 return getFirstInteger(bankingConfigAttrs.dimensions);
353}
354
355Attribute getRemainingBankingInfo(MLIRContext *context,
356 BankingConfigAttributes bankingConfigAttrs,
357 StringRef attrName) {
358 auto getRemainingElements = [context](Attribute attr) -> Attribute {
359 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
360 assert(!arrayAttr.empty() &&
361 "BankingConfig ArrayAttr should not be empty");
362 return arrayAttr.size() > 1
363 ? ArrayAttr::get(context, arrayAttr.getValue().take_back(
364 arrayAttr.size() - 1))
365 : nullptr;
366 }
367 assert(dyn_cast<IntegerAttr>(attr) &&
368 "BankingConfig attribute must be an integer");
369 return nullptr;
370 };
371
372 if (attrName.str() == bankingFactorsStr) {
373 return getRemainingElements(bankingConfigAttrs.factors);
374 }
375
376 assert(attrName.str() == bankingDimensionsStr &&
377 "BankingConfig only contains 'factors' and 'dimensions' attributes");
378 return getRemainingElements(bankingConfigAttrs.dimensions);
379}
380
381SmallVector<Value, 4> MemoryBankingPass::createBanks(OpBuilder &builder,
382 Value originalMem) {
383 MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
384
385 MLIRContext *context = builder.getContext();
386
387 BankingConfigAttributes currBankingConfig =
388 getMemRefBankingConfig(originalMem);
389
390 unsigned currFactor =
391 getCurrBankingInfo(currBankingConfig, bankingFactorsStr);
392 unsigned currDimension =
393 getCurrBankingInfo(currBankingConfig, bankingDimensionsStr);
394
395 verifyBankingConfigurations(currFactor, currDimension, originalMemRefType);
396
397 Attribute remainingFactors =
398 getRemainingBankingInfo(context, currBankingConfig, bankingFactorsStr);
399 Attribute remainingDimensions =
400 getRemainingBankingInfo(context, currBankingConfig, bankingDimensionsStr);
401 DictionaryAttr remainingAttrs =
402 remainingFactors
403 ? DictionaryAttr::get(
404 context,
405 {builder.getNamedAttr(bankingFactorsStr, remainingFactors),
406 builder.getNamedAttr(bankingDimensionsStr,
407 remainingDimensions)})
408 : DictionaryAttr::get(context);
409
410 MemRefType newMemRefType =
411 computeBankedMemRefType(originalMemRefType, currFactor, currDimension);
412 SmallVector<Value, 4> banks;
413 if (auto blockArgMem = dyn_cast<BlockArgument>(originalMem)) {
414 Block *block = blockArgMem.getOwner();
415 unsigned blockArgNum = blockArgMem.getArgNumber();
416
417 for (unsigned i = 0; i < currFactor; ++i)
418 block->insertArgument(blockArgNum + 1 + i, newMemRefType,
419 blockArgMem.getLoc());
420
421 auto blockArgs = block->getArguments().slice(blockArgNum + 1, currFactor);
422 banks.append(blockArgs.begin(), blockArgs.end());
423
424 auto *parentOp = block->getParentOp();
425 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
426 assert(funcOp && "BlockArgument is not part of a FuncOp");
427 // Update the ArgumentTypes of `funcOp` so that we can correctly get
428 // `getArgAttrDict` when resolving banking attributes across the iterations
429 // of creating new banks.
430 updateFuncOpArgumentTypes(funcOp, blockArgNum, newMemRefType, currFactor);
431 updateFuncOpArgAttrs(funcOp, blockArgNum, currFactor, remainingAttrs);
432 } else {
433 Operation *originalDef = originalMem.getDefiningOp();
434 Location loc = originalDef->getLoc();
435 builder.setInsertionPointAfter(originalDef);
436 TypeSwitch<Operation *>(originalDef)
437 .Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
438 for (uint64_t bankCnt = 0; bankCnt < currFactor; ++bankCnt) {
439 auto bankAllocOp =
440 builder.create<memref::AllocOp>(loc, newMemRefType);
441 bankAllocOp->setAttrs(remainingAttrs);
442 banks.push_back(bankAllocOp);
443 }
444 })
445 .Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
446 for (uint64_t bankCnt = 0; bankCnt < currFactor; ++bankCnt) {
447 auto bankAllocaOp =
448 builder.create<memref::AllocaOp>(loc, newMemRefType);
449 bankAllocaOp->setAttrs(remainingAttrs);
450 banks.push_back(bankAllocaOp);
451 }
452 })
453 .Case<memref::GetGlobalOp>([&](memref::GetGlobalOp getGlobalOp) {
454 auto newBanks =
455 handleGetGlobalOp(getGlobalOp, currFactor, currDimension,
456 newMemRefType, builder, remainingAttrs);
457 banks.append(newBanks.begin(), newBanks.end());
458 })
459 .Default([](Operation *) {
460 llvm_unreachable("Unhandled memory operation type");
461 });
462 }
463
464 return banks;
465}
466
467// Replace the original load operations with newly created memory banks
469 : public OpRewritePattern<mlir::affine::AffineLoadOp> {
470 BankAffineLoadPattern(MLIRContext *context,
471 DenseMap<Value, SmallVector<Value>> &memoryToBanks,
472 DenseSet<Value> &oldMemRefVals)
473 : OpRewritePattern<mlir::affine::AffineLoadOp>(context),
475
476 LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp,
477 PatternRewriter &rewriter) const override {
478 Location loc = loadOp.getLoc();
479 auto originalMem = loadOp.getMemref();
480 auto banks = memoryToBanks[originalMem];
481 auto loadIndices = loadOp.getIndices();
482 MemRefType originalMemRefType = loadOp.getMemRefType();
483 int64_t memrefRank = originalMemRefType.getRank();
484
485 BankingConfigAttributes currBankingConfig =
486 getMemRefBankingConfig(originalMem);
487 if (!currBankingConfig.factors) {
488 // No need to rewrite anymore.
489 return failure();
490 }
491
492 unsigned currFactor =
493 getCurrBankingInfo(currBankingConfig, bankingFactorsStr);
494 unsigned currDimension =
495 getCurrBankingInfo(currBankingConfig, bankingDimensionsStr);
496
497 verifyBankingConfigurations(currFactor, currDimension, originalMemRefType);
498
499 auto modMap = AffineMap::get(
500 /*dimCount=*/memrefRank, /*symbolCount=*/0,
501 {rewriter.getAffineDimExpr(currDimension) % currFactor});
502 auto divMap = AffineMap::get(
503 memrefRank, 0,
504 {rewriter.getAffineDimExpr(currDimension).floorDiv(currFactor)});
505
506 Value bankIndex =
507 rewriter.create<affine::AffineApplyOp>(loc, modMap, loadIndices);
508 Value offset =
509 rewriter.create<affine::AffineApplyOp>(loc, divMap, loadIndices);
510 SmallVector<Value, 4> newIndices(loadIndices.begin(), loadIndices.end());
511 newIndices[currDimension] = offset;
512
513 SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
514
515 SmallVector<int64_t, 4> caseValues;
516 for (unsigned i = 0; i < currFactor; ++i)
517 caseValues.push_back(i);
518
519 rewriter.setInsertionPoint(loadOp);
520 scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
521 loc, resultTypes, bankIndex, caseValues,
522 /*numRegions=*/currFactor);
523
524 for (unsigned i = 0; i < currFactor; ++i) {
525 Region &caseRegion = switchOp.getCaseRegions()[i];
526 rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
527 Value bankedLoad = rewriter.create<mlir::affine::AffineLoadOp>(
528 loc, banks[i], newIndices);
529 rewriter.create<scf::YieldOp>(loc, bankedLoad);
530 }
531
532 Region &defaultRegion = switchOp.getDefaultRegion();
533 assert(defaultRegion.empty() && "Default region should be empty");
534 rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
535
536 TypedAttr zeroAttr =
537 cast<TypedAttr>(rewriter.getZeroAttr(loadOp.getType()));
538 auto defaultValue = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
539 rewriter.create<scf::YieldOp>(loc, defaultValue.getResult());
540
541 // We track Load's memory reference only if it is a block argument - this is
542 // the only case where the reference isn't replaced.
543 if (Value memRef = loadOp.getMemref(); isa<BlockArgument>(memRef)) {
544 oldMemRefVals.insert(memRef);
545 }
546 rewriter.replaceOp(loadOp, switchOp.getResult(0));
547
548 return success();
549 }
550
551private:
552 DenseMap<Value, SmallVector<Value>> &memoryToBanks;
553 DenseSet<Value> &oldMemRefVals;
554};
555
556// Replace the original store operations with newly created memory banks
558 : public OpRewritePattern<mlir::affine::AffineStoreOp> {
559 BankAffineStorePattern(MLIRContext *context,
560 DenseMap<Value, SmallVector<Value>> &memoryToBanks,
561 DenseSet<Operation *> &opsToErase,
562 DenseSet<Operation *> &processedOps,
563 DenseSet<Value> &oldMemRefVals)
564 : OpRewritePattern<mlir::affine::AffineStoreOp>(context),
567
568 LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp,
569 PatternRewriter &rewriter) const override {
570 if (processedOps.contains(storeOp)) {
571 return failure();
572 }
573 auto currConfig = getMemRefBankingConfig(storeOp.getMemref());
574 if (!currConfig.factors) {
575 // No need to rewrite anymore.
576 return failure();
577 }
578 Location loc = storeOp.getLoc();
579 auto originalMem = storeOp.getMemref();
580 auto banks = memoryToBanks[originalMem];
581 auto storeIndices = storeOp.getIndices();
582 auto originalMemRefType = storeOp.getMemRefType();
583 int64_t memrefRank = originalMemRefType.getRank();
584
585 BankingConfigAttributes currBankingConfig =
586 getMemRefBankingConfig(originalMem);
587
588 unsigned currFactor =
589 getCurrBankingInfo(currBankingConfig, bankingFactorsStr);
590 unsigned currDimension =
591 getCurrBankingInfo(currBankingConfig, bankingDimensionsStr);
592
593 verifyBankingConfigurations(currFactor, currDimension, originalMemRefType);
594
595 auto modMap = AffineMap::get(
596 /*dimCount=*/memrefRank, /*symbolCount=*/0,
597 {rewriter.getAffineDimExpr(currDimension) % currFactor});
598 auto divMap = AffineMap::get(
599 memrefRank, 0,
600 {rewriter.getAffineDimExpr(currDimension).floorDiv(currFactor)});
601
602 Value bankIndex =
603 rewriter.create<affine::AffineApplyOp>(loc, modMap, storeIndices);
604 Value offset =
605 rewriter.create<affine::AffineApplyOp>(loc, divMap, storeIndices);
606 SmallVector<Value, 4> newIndices(storeIndices.begin(), storeIndices.end());
607 newIndices[currDimension] = offset;
608
609 SmallVector<Type> resultTypes = {};
610
611 SmallVector<int64_t, 4> caseValues;
612 for (unsigned i = 0; i < currFactor; ++i)
613 caseValues.push_back(i);
614
615 rewriter.setInsertionPoint(storeOp);
616 scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
617 loc, resultTypes, bankIndex, caseValues,
618 /*numRegions=*/currFactor);
619
620 for (unsigned i = 0; i < currFactor; ++i) {
621 Region &caseRegion = switchOp.getCaseRegions()[i];
622 rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
623 rewriter.create<mlir::affine::AffineStoreOp>(
624 loc, storeOp.getValueToStore(), banks[i], newIndices);
625 rewriter.create<scf::YieldOp>(loc);
626 }
627
628 Region &defaultRegion = switchOp.getDefaultRegion();
629 assert(defaultRegion.empty() && "Default region should be empty");
630 rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
631
632 rewriter.create<scf::YieldOp>(loc);
633
634 processedOps.insert(storeOp);
635 opsToErase.insert(storeOp);
636 oldMemRefVals.insert(storeOp.getMemref());
637
638 return success();
639 }
640
641private:
642 DenseMap<Value, SmallVector<Value>> &memoryToBanks;
643 DenseSet<Operation *> &opsToErase;
644 DenseSet<Operation *> &processedOps;
645 DenseSet<Value> &oldMemRefVals;
646};
647
648// Replace the original return operation with newly created memory banks
649struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
650 BankReturnPattern(MLIRContext *context,
651 DenseMap<Value, SmallVector<Value>> &memoryToBanks)
652 : OpRewritePattern<func::ReturnOp>(context),
654
655 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
656 PatternRewriter &rewriter) const override {
657 Location loc = returnOp.getLoc();
658 SmallVector<Value, 4> newReturnOperands;
659 bool allOrigMemsUsedByReturn = true;
660 for (auto operand : returnOp.getOperands()) {
661 if (!memoryToBanks.contains(operand)) {
662 newReturnOperands.push_back(operand);
663 continue;
664 }
665 if (operand.hasOneUse())
666 allOrigMemsUsedByReturn = false;
667 auto banks = memoryToBanks[operand];
668 newReturnOperands.append(banks.begin(), banks.end());
669 }
670
671 func::FuncOp funcOp = returnOp.getParentOp();
672 rewriter.setInsertionPointToEnd(&funcOp.getBlocks().front());
673 auto newReturnOp =
674 rewriter.create<func::ReturnOp>(loc, ValueRange(newReturnOperands));
675 TypeRange newReturnType = TypeRange(newReturnOperands);
676 FunctionType newFuncType =
677 FunctionType::get(funcOp.getContext(),
678 funcOp.getFunctionType().getInputs(), newReturnType);
679 funcOp.setType(newFuncType);
680
681 if (allOrigMemsUsedByReturn)
682 rewriter.replaceOp(returnOp, newReturnOp);
683
684 return success();
685 }
686
687private:
688 DenseMap<Value, SmallVector<Value>> &memoryToBanks;
689};
690
691// Clean up the empty uses old memory values by either erasing the defining
692// operation or replace the block arguments with new ones that corresponds to
693// the newly created banks. Change the function signature if the old memory
694// values are used as function arguments and/or return values.
695LogicalResult cleanUpOldMemRefs(DenseSet<Value> &oldMemRefVals,
696 DenseSet<Operation *> &opsToErase) {
697 DenseSet<func::FuncOp> funcsToModify;
698 SmallVector<Value, 4> valuesToErase;
699 DenseMap<func::FuncOp, SmallVector<unsigned, 4>> erasedArgIndices;
700 for (auto &memrefVal : oldMemRefVals) {
701 valuesToErase.push_back(memrefVal);
702 if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
703 if (auto funcOp =
704 dyn_cast<func::FuncOp>(blockArg.getOwner()->getParentOp())) {
705 funcsToModify.insert(funcOp);
706 erasedArgIndices[funcOp].push_back(blockArg.getArgNumber());
707 }
708 }
709 }
710
711 for (auto *op : opsToErase) {
712 op->erase();
713 }
714 // Erase values safely.
715 for (auto &memrefVal : valuesToErase) {
716 assert(memrefVal.use_empty() && "use must be empty");
717 if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
718 blockArg.getOwner()->eraseArgument(blockArg.getArgNumber());
719 } else if (auto *op = memrefVal.getDefiningOp()) {
720 op->erase();
721 }
722 }
723
724 // Modify the function argument attributes and function type accordingly
725 for (auto funcOp : funcsToModify) {
726 ArrayAttr existingArgAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs");
727 if (existingArgAttrs) {
728 SmallVector<Attribute, 4> updatedArgAttrs;
729 auto erasedIndices = erasedArgIndices[funcOp];
730 DenseSet<unsigned> indicesToErase(erasedIndices.begin(),
731 erasedIndices.end());
732 for (unsigned i = 0; i < existingArgAttrs.size(); ++i) {
733 if (!indicesToErase.contains(i))
734 updatedArgAttrs.push_back(existingArgAttrs[i]);
735 }
736
737 funcOp->setAttr("arg_attrs",
738 ArrayAttr::get(funcOp.getContext(), updatedArgAttrs));
739 }
740
741 SmallVector<Type, 4> newArgTypes;
742 for (BlockArgument arg : funcOp.getArguments()) {
743 newArgTypes.push_back(arg.getType());
744 }
745 FunctionType newFuncType =
746 FunctionType::get(funcOp.getContext(), newArgTypes,
747 funcOp.getFunctionType().getResults());
748 funcOp.setType(newFuncType);
749 }
750
751 return success();
752}
753
754void verifyBankingAttributesSize(Attribute bankingFactorsAttr,
755 Attribute bankingDimensionsAttr) {
756 if (auto factorsArrayAttr = dyn_cast<ArrayAttr>(bankingFactorsAttr)) {
757 assert(!factorsArrayAttr.empty() && "Banking factors should not be empty");
758 if (auto dimsArrayAttr = dyn_cast<ArrayAttr>(bankingDimensionsAttr)) {
759 assert(factorsArrayAttr.size() == dimsArrayAttr.size() &&
760 "Banking factors/dimensions must be paired together");
761 } else {
762 auto dimsIntAttr = dyn_cast<IntegerAttr>(bankingDimensionsAttr);
763 assert(dimsIntAttr && "banking.dimensions can either be an integer or an "
764 "array of integers");
765 assert(factorsArrayAttr.size() == 1 &&
766 "Banking factors/dimensions must be paired together");
767 }
768 } else {
769 auto factorsIntAttr = dyn_cast<IntegerAttr>(bankingFactorsAttr);
770 assert(factorsIntAttr &&
771 "banking.factors can either be an integer or an array of integers");
772 if (auto dimsArrayAttr = dyn_cast<ArrayAttr>(bankingDimensionsAttr)) {
773 assert(dimsArrayAttr.size() == 1 &&
774 "Banking factors/dimensions must be paired together");
775 } else {
776 auto dimsIntAttr = dyn_cast<IntegerAttr>(bankingDimensionsAttr);
777 assert(dimsIntAttr && "banking.dimensions can either be an integer or an "
778 "array of integers");
779 }
780 }
781}
782
783void MemoryBankingPass::setAllBankingAttributes(Operation *operation,
784 MLIRContext *context) {
785 ArrayAttr defaultFactorsAttr = ArrayAttr::get(
786 context,
787 llvm::map_to_vector(bankingFactors, [&](unsigned factor) -> Attribute {
788 return IntegerAttr::get(IntegerType::get(context, 32), factor);
789 }));
790
791 auto getDimensionsAttr =
792 [&](SmallVector<unsigned, 4> specifiedOrDefaultDims) -> ArrayAttr {
793 return ArrayAttr::get(
794 context, llvm::map_to_vector(specifiedOrDefaultDims,
795 [&](unsigned dim) -> Attribute {
796 return IntegerAttr::get(
797 IntegerType::get(context, 32), dim);
798 }));
799 };
800
801 // Set or keep the memory banking related attributes for every memory-involved
802 // affine operation.
803 operation->walk([&](affine::AffineParallelOp affineParallelOp) {
804 affineParallelOp.walk([&](Operation *op) {
805 if (!isa<affine::AffineWriteOpInterface, affine::AffineReadOpInterface>(
806 op))
807 return WalkResult::advance();
808
809 auto read = dyn_cast<affine::AffineReadOpInterface>(op);
810 Value memref = read
811 ? read.getMemRef()
812 : cast<affine::AffineWriteOpInterface>(op).getMemRef();
813 MemRefType memrefType =
814 read ? read.getMemRefType()
815 : cast<affine::AffineWriteOpInterface>(op).getMemRefType();
816
817 if (auto *originalDef = memref.getDefiningOp()) {
818 // Set the default factors using the command line option.
819 if (!originalDef->getAttr(bankingFactorsStr)) {
820 originalDef->setAttr(bankingFactorsStr, defaultFactorsAttr);
821 }
822
823 // Set the default `dimensions` either by the command line option or
824 // inferencing if unspecified.
825 if (!originalDef->getAttr(bankingDimensionsStr)) {
826 SmallVector<unsigned, 4> specifiedOrDefaultDims =
827 getSpecifiedOrDefaultBankingDim(bankingDimensions,
828 memrefType.getRank(),
829 memrefType.getShape());
830
831 originalDef->setAttr(bankingDimensionsStr,
832 getDimensionsAttr(specifiedOrDefaultDims));
833 }
834
835 verifyBankingAttributesSize(originalDef->getAttr(bankingFactorsStr),
836 originalDef->getAttr(bankingDimensionsStr));
837 } else if (isa<BlockArgument>(memref)) {
838 auto blockArg = cast<BlockArgument>(memref);
839 auto *parentOp = blockArg.getOwner()->getParentOp();
840 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
841 assert(funcOp &&
842 "Expected the original memory to be a FuncOp block argument!");
843 unsigned argIndex = blockArg.getArgNumber();
844 SmallVector<unsigned, 4> specifiedOrDefaultDims =
846 bankingDimensions, memrefType.getRank(), memrefType.getShape());
847
848 if (!funcOp.getArgAttr(argIndex, bankingFactorsStr))
849 funcOp.setArgAttr(argIndex, bankingFactorsStr, defaultFactorsAttr);
850 if (!funcOp.getArgAttr(argIndex, bankingDimensionsStr))
851 funcOp.setArgAttr(argIndex, bankingDimensionsStr,
852 getDimensionsAttr(specifiedOrDefaultDims));
853
855 funcOp.getArgAttr(argIndex, bankingFactorsStr),
856 funcOp.getArgAttr(argIndex, bankingDimensionsStr));
857 }
858 return WalkResult::advance();
859 });
860 });
861}
862
863void MemoryBankingPass::runOnOperation() {
864 this->bankingFactors = {bankingFactorsList.begin(), bankingFactorsList.end()};
865 this->bankingDimensions = {bankingDimensionsList.begin(),
866 bankingDimensionsList.end()};
867
868 if (getOperation().isExternal() ||
869 (bankingFactors.empty() ||
870 std::all_of(bankingFactors.begin(), bankingFactors.end(),
871 [](unsigned f) { return f == 1; })))
872 return;
873
874 if (std::any_of(bankingFactors.begin(), bankingFactors.end(),
875 [](int f) { return f == 0; })) {
876 getOperation().emitError("banking factor must be greater than 1");
877 signalPassFailure();
878 return;
879 }
880
881 if (bankingDimensions.size() > bankingFactors.size()) {
882 getOperation().emitError(
883 "A banking dimension must be paired with a factor");
884 signalPassFailure();
885 return;
886 }
887 // `bankingFactors` is guaranteed to have elements and at least one of them is
888 // greater than 1 beyond this point.
889
890 setAllBankingAttributes(getOperation(), &getContext());
891
892 OpBuilder builder(getOperation());
893 // We run this pass until convergence, i.e., `applyMemoryBanking` has reached
894 // its fixed point, which means every memory read/write operation has been
895 // rewritten to be using the newly created banks, and that the old memory
896 // references are erased.
897 bool banksCreated = false;
898 do {
899 memoryToBanks.clear();
900 oldMemRefVals.clear();
901 opsToErase.clear();
902
903 banksCreated = false;
904 getOperation().walk([&](mlir::affine::AffineParallelOp parOp) {
905 DenseSet<Value> memrefsInPar = collectMemRefs(parOp);
906 // We run `createBanks` iff there exists some `memrefVal` s.t. it has
907 // banking attributes attached to it.
908 for (auto memrefVal : memrefsInPar) {
909 auto currConfig = getMemRefBankingConfig(memrefVal);
910 if (!currConfig.factors) {
911 continue;
912 }
913 auto [it, inserted] = memoryToBanks.insert(
914 std::make_pair(memrefVal, SmallVector<Value>{}));
915 if (inserted)
916 it->second = createBanks(builder, memrefVal);
917 banksCreated = true;
918 }
919 });
920
921 if (failed(applyMemoryBanking(getOperation(), &getContext()))) {
922 signalPassFailure();
923 break;
924 }
925 } while (banksCreated);
926}
927
928LogicalResult MemoryBankingPass::applyMemoryBanking(Operation *operation,
929 MLIRContext *ctx) {
930 RewritePatternSet patterns(ctx);
931
932 DenseSet<Operation *> processedOps;
933 patterns.add<BankAffineLoadPattern>(ctx, memoryToBanks, oldMemRefVals);
934 patterns.add<BankAffineStorePattern>(ctx, memoryToBanks, opsToErase,
935 processedOps, oldMemRefVals);
936 patterns.add<BankReturnPattern>(ctx, memoryToBanks);
937
938 GreedyRewriteConfig config;
939 config.strictMode = GreedyRewriteStrictness::ExistingOps;
940 if (failed(applyPatternsGreedily(operation, std::move(patterns), config))) {
941 return failure();
942 }
943
944 // Clean up the old memref values
945 if (failed(cleanUpOldMemRefs(oldMemRefVals, opsToErase))) {
946 return failure();
947 }
948
949 return success();
950}
951
952namespace circt {
953std::unique_ptr<mlir::Pass>
954createMemoryBankingPass(ArrayRef<unsigned> bankingFactors,
955 ArrayRef<unsigned> bankingDimensions) {
956 return std::make_unique<MemoryBankingPass>(bankingFactors, bankingDimensions);
957}
958} // namespace circt
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
BankingConfigAttributes getMemRefBankingConfig(Value originalMem)
void verifyBankingAttributesSize(Attribute bankingFactorsAttr, Attribute bankingDimensionsAttr)
void verifyBankingConfigurations(unsigned bankingFactor, unsigned bankingDimension, MemRefType originalType)
unsigned getCurrBankingInfo(BankingConfigAttributes bankingConfigAttrs, StringRef attrName)
DenseSet< Value > collectMemRefs(affine::AffineParallelOp affineParallelOp)
Attribute getRemainingBankingInfo(MLIRContext *context, BankingConfigAttributes bankingConfigAttrs, StringRef attrName)
SmallVector< int64_t > decodeIndex(int64_t linIndex, ArrayRef< int64_t > shape)
LogicalResult cleanUpOldMemRefs(DenseSet< Value > &oldMemRefVals, DenseSet< Operation * > &opsToErase)
SmallVector< Value, 4 > handleGetGlobalOp(memref::GetGlobalOp getGlobalOp, uint64_t bankingFactor, unsigned bankingDimension, MemRefType newMemRefType, OpBuilder &builder, DictionaryAttr remainingAttrs)
void updateFuncOpArgumentTypes(func::FuncOp funcOp, unsigned argIndex, MemRefType newMemRefType, unsigned numInsertedArgs)
MemRefType computeBankedMemRefType(MemRefType originalType, uint64_t bankingFactor, unsigned bankingDimension)
SmallVector< unsigned, 4 > getSpecifiedOrDefaultBankingDim(const ArrayRef< unsigned > bankingDimensions, int64_t rank, ArrayRef< int64_t > shape)
void updateFuncOpArgAttrs(func::FuncOp funcOp, unsigned argIndex, unsigned numInsertedArgs, DictionaryAttr remainingAttrs)
SmallVector< SmallVector< Attribute > > sliceSubBlock(ArrayRef< Attribute > allAttrs, ArrayRef< int64_t > memShape, unsigned bankingDimension, unsigned bankingFactor)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createMemoryBankingPass(ArrayRef< unsigned > bankingFactors={}, ArrayRef< unsigned > bankingDimensions={})
DenseMap< Value, SmallVector< Value > > & memoryToBanks
LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp, PatternRewriter &rewriter) const override
BankAffineLoadPattern(MLIRContext *context, DenseMap< Value, SmallVector< Value > > &memoryToBanks, DenseSet< Value > &oldMemRefVals)
DenseSet< Value > & oldMemRefVals
DenseMap< Value, SmallVector< Value > > & memoryToBanks
DenseSet< Value > & oldMemRefVals
DenseSet< Operation * > & processedOps
BankAffineStorePattern(MLIRContext *context, DenseMap< Value, SmallVector< Value > > &memoryToBanks, DenseSet< Operation * > &opsToErase, DenseSet< Operation * > &processedOps, DenseSet< Value > &oldMemRefVals)
DenseSet< Operation * > & opsToErase
LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp, PatternRewriter &rewriter) const override
DenseMap< Value, SmallVector< Value > > & memoryToBanks
BankReturnPattern(MLIRContext *context, DenseMap< Value, SmallVector< Value > > &memoryToBanks)
LogicalResult matchAndRewrite(func::ReturnOp returnOp, PatternRewriter &rewriter) const override