CIRCT  20.0.0git
MemoryBanking.cpp
Go to the documentation of this file.
1 //===- MemoryBanking.cpp - memory bank parallel loops -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements parallel loop memory banking.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "circt/Support/LLVM.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/OperationSupport.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Support/LLVM.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 namespace circt {
30 #define GEN_PASS_DEF_MEMORYBANKING
31 #include "circt/Transforms/Passes.h.inc"
32 } // namespace circt
33 
34 using namespace mlir;
35 using namespace circt;
36 
37 namespace {
38 
39 /// Partition memories used in `affine.parallel` operation by the
40 /// `bankingFactor` throughout the program.
41 struct MemoryBankingPass
42  : public circt::impl::MemoryBankingBase<MemoryBankingPass> {
43  MemoryBankingPass(const MemoryBankingPass &other) = default;
44  explicit MemoryBankingPass(
45  std::optional<unsigned> bankingFactor = std::nullopt) {}
46 
47  void runOnOperation() override;
48 
49 private:
50  // map from original memory definition to newly allocated banks
51  DenseMap<Value, SmallVector<Value>> memoryToBanks;
52  DenseSet<Operation *> opsToErase;
53 };
54 } // namespace
55 
56 // Collect all memref in the `parOp`'s region'
57 DenseSet<Value> collectMemRefs(mlir::affine::AffineParallelOp parOp) {
58  DenseSet<Value> memrefVals;
59  parOp.walk([&](Operation *op) {
60  for (auto operand : op->getOperands()) {
61  if (isa<MemRefType>(operand.getType()))
62  memrefVals.insert(operand);
63  }
64  return WalkResult::advance();
65  });
66  return memrefVals;
67 }
68 
69 MemRefType computeBankedMemRefType(MemRefType originalType,
70  uint64_t bankingFactor) {
71  ArrayRef<int64_t> originalShape = originalType.getShape();
72  assert(!originalShape.empty() && "memref shape should not be empty");
73  assert(originalType.getRank() == 1 &&
74  "currently only support one dimension memories");
75  SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
76  assert(newShape.front() % bankingFactor == 0 &&
77  "memref shape must be evenly divided by the banking factor");
78  newShape.front() /= bankingFactor;
79  MemRefType newMemRefType =
80  MemRefType::get(newShape, originalType.getElementType(),
81  originalType.getLayout(), originalType.getMemorySpace());
82 
83  return newMemRefType;
84 }
85 
86 SmallVector<Value, 4> createBanks(Value originalMem, uint64_t bankingFactor) {
87  MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
88  MemRefType newMemRefType =
89  computeBankedMemRefType(originalMemRefType, bankingFactor);
90  SmallVector<Value, 4> banks;
91  if (auto blockArgMem = dyn_cast<BlockArgument>(originalMem)) {
92  Block *block = blockArgMem.getOwner();
93  unsigned blockArgNum = blockArgMem.getArgNumber();
94 
95  SmallVector<Type> banksType;
96  for (unsigned i = 0; i < bankingFactor; ++i) {
97  block->insertArgument(blockArgNum + 1 + i, newMemRefType,
98  blockArgMem.getLoc());
99  }
100 
101  auto blockArgs =
102  block->getArguments().slice(blockArgNum + 1, bankingFactor);
103  banks.append(blockArgs.begin(), blockArgs.end());
104  } else {
105  Operation *originalDef = originalMem.getDefiningOp();
106  Location loc = originalDef->getLoc();
107  OpBuilder builder(originalDef);
108  builder.setInsertionPointAfter(originalDef);
109  TypeSwitch<Operation *>(originalDef)
110  .Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
111  for (uint64_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
112  auto bankAllocOp =
113  builder.create<memref::AllocOp>(loc, newMemRefType);
114  banks.push_back(bankAllocOp);
115  }
116  })
117  .Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
118  for (uint64_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
119  auto bankAllocaOp =
120  builder.create<memref::AllocaOp>(loc, newMemRefType);
121  banks.push_back(bankAllocaOp);
122  }
123  })
124  .Default([](Operation *) {
125  llvm_unreachable("Unhandled memory operation type");
126  });
127  }
128  return banks;
129 }
130 
131 // Replace the original load operations with newly created memory banks
133  : public OpRewritePattern<mlir::affine::AffineLoadOp> {
134  BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor,
135  DenseMap<Value, SmallVector<Value>> &memoryToBanks)
136  : OpRewritePattern<mlir::affine::AffineLoadOp>(context),
137  bankingFactor(bankingFactor), memoryToBanks(memoryToBanks) {}
138 
139  LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp,
140  PatternRewriter &rewriter) const override {
141  Location loc = loadOp.getLoc();
142  auto banks = memoryToBanks[loadOp.getMemref()];
143  Value loadIndex = loadOp.getIndices().front();
144  auto modMap =
145  AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor});
146  auto divMap = AffineMap::get(
147  1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)});
148 
149  Value bankIndex = rewriter.create<affine::AffineApplyOp>(
150  loc, modMap, loadIndex); // assuming one-dim
151  Value offset =
152  rewriter.create<affine::AffineApplyOp>(loc, divMap, loadIndex);
153 
154  SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
155 
156  SmallVector<int64_t, 4> caseValues;
157  for (unsigned i = 0; i < bankingFactor; ++i)
158  caseValues.push_back(i);
159 
160  rewriter.setInsertionPoint(loadOp);
161  scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
162  loc, resultTypes, bankIndex, caseValues,
163  /*numRegions=*/bankingFactor);
164 
165  for (unsigned i = 0; i < bankingFactor; ++i) {
166  Region &caseRegion = switchOp.getCaseRegions()[i];
167  rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
168  Value bankedLoad =
169  rewriter.create<mlir::affine::AffineLoadOp>(loc, banks[i], offset);
170  rewriter.create<scf::YieldOp>(loc, bankedLoad);
171  }
172 
173  Region &defaultRegion = switchOp.getDefaultRegion();
174  assert(defaultRegion.empty() && "Default region should be empty");
175  rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
176 
177  TypedAttr zeroAttr =
178  cast<TypedAttr>(rewriter.getZeroAttr(loadOp.getType()));
179  auto defaultValue = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
180  rewriter.create<scf::YieldOp>(loc, defaultValue.getResult());
181 
182  rewriter.replaceOp(loadOp, switchOp.getResult(0));
183 
184  return success();
185  }
186 
187 private:
188  uint64_t bankingFactor;
189  DenseMap<Value, SmallVector<Value>> &memoryToBanks;
190 };
191 
192 // Replace the original store operations with newly created memory banks
194  : public OpRewritePattern<mlir::affine::AffineStoreOp> {
195  BankAffineStorePattern(MLIRContext *context, uint64_t bankingFactor,
196  DenseMap<Value, SmallVector<Value>> &memoryToBanks,
197  DenseSet<Operation *> &opsToErase,
198  DenseSet<Operation *> &processedOps)
199  : OpRewritePattern<mlir::affine::AffineStoreOp>(context),
200  bankingFactor(bankingFactor), memoryToBanks(memoryToBanks),
201  opsToErase(opsToErase), processedOps(processedOps) {}
202 
203  LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp,
204  PatternRewriter &rewriter) const override {
205  if (processedOps.contains(storeOp)) {
206  return failure();
207  }
208  Location loc = storeOp.getLoc();
209  auto banks = memoryToBanks[storeOp.getMemref()];
210  Value storeIndex = storeOp.getIndices().front();
211 
212  auto modMap =
213  AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor});
214  auto divMap = AffineMap::get(
215  1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)});
216 
217  Value bankIndex = rewriter.create<affine::AffineApplyOp>(
218  loc, modMap, storeIndex); // assuming one-dim
219  Value offset =
220  rewriter.create<affine::AffineApplyOp>(loc, divMap, storeIndex);
221 
222  SmallVector<Type> resultTypes = {};
223 
224  SmallVector<int64_t, 4> caseValues;
225  for (unsigned i = 0; i < bankingFactor; ++i)
226  caseValues.push_back(i);
227 
228  rewriter.setInsertionPoint(storeOp);
229  scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
230  loc, resultTypes, bankIndex, caseValues,
231  /*numRegions=*/bankingFactor);
232 
233  for (unsigned i = 0; i < bankingFactor; ++i) {
234  Region &caseRegion = switchOp.getCaseRegions()[i];
235  rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
236  rewriter.create<mlir::affine::AffineStoreOp>(
237  loc, storeOp.getValueToStore(), banks[i], offset);
238  rewriter.create<scf::YieldOp>(loc);
239  }
240 
241  Region &defaultRegion = switchOp.getDefaultRegion();
242  assert(defaultRegion.empty() && "Default region should be empty");
243  rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
244 
245  rewriter.create<scf::YieldOp>(loc);
246 
247  processedOps.insert(storeOp);
248  opsToErase.insert(storeOp);
249 
250  return success();
251  }
252 
253 private:
254  uint64_t bankingFactor;
255  DenseMap<Value, SmallVector<Value>> &memoryToBanks;
256  DenseSet<Operation *> &opsToErase;
257  DenseSet<Operation *> &processedOps;
258 };
259 
260 // Replace the original return operation with newly created memory banks
261 struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
262  BankReturnPattern(MLIRContext *context,
263  DenseMap<Value, SmallVector<Value>> &memoryToBanks)
264  : OpRewritePattern<func::ReturnOp>(context),
265  memoryToBanks(memoryToBanks) {}
266 
267  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
268  PatternRewriter &rewriter) const override {
269  Location loc = returnOp.getLoc();
270  SmallVector<Value, 4> newReturnOperands;
271  bool allOrigMemsUsedByReturn = true;
272  for (auto operand : returnOp.getOperands()) {
273  if (!memoryToBanks.contains(operand)) {
274  newReturnOperands.push_back(operand);
275  continue;
276  }
277  if (operand.hasOneUse())
278  allOrigMemsUsedByReturn = false;
279  auto banks = memoryToBanks[operand];
280  newReturnOperands.append(banks.begin(), banks.end());
281  }
282 
283  func::FuncOp funcOp = returnOp.getParentOp();
284  rewriter.setInsertionPointToEnd(&funcOp.getBlocks().front());
285  auto newReturnOp =
286  rewriter.create<func::ReturnOp>(loc, ValueRange(newReturnOperands));
287  TypeRange newReturnType = TypeRange(newReturnOperands);
288  FunctionType newFuncType =
289  FunctionType::get(funcOp.getContext(),
290  funcOp.getFunctionType().getInputs(), newReturnType);
291  funcOp.setType(newFuncType);
292 
293  if (allOrigMemsUsedByReturn)
294  rewriter.replaceOp(returnOp, newReturnOp);
295 
296  return success();
297  }
298 
299 private:
300  DenseMap<Value, SmallVector<Value>> &memoryToBanks;
301 };
302 
303 // Clean up the empty uses old memory values by either erasing the defining
304 // operation or replace the block arguments with new ones that corresponds to
305 // the newly created banks. Change the function signature if the old memory
306 // values are used as function arguments and/or return values.
307 LogicalResult cleanUpOldMemRefs(DenseSet<Value> &oldMemRefVals,
308  DenseSet<Operation *> &opsToErase) {
309  DenseSet<func::FuncOp> funcsToModify;
310  SmallVector<Value, 4> valuesToErase;
311  for (auto &memrefVal : oldMemRefVals) {
312  valuesToErase.push_back(memrefVal);
313  if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
314  if (auto funcOp =
315  dyn_cast<func::FuncOp>(blockArg.getOwner()->getParentOp()))
316  funcsToModify.insert(funcOp);
317  }
318  }
319 
320  for (auto *op : opsToErase) {
321  op->erase();
322  }
323  // Erase values safely.
324  for (auto &memrefVal : valuesToErase) {
325  assert(memrefVal.use_empty() && "use must be empty");
326  if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
327  blockArg.getOwner()->eraseArgument(blockArg.getArgNumber());
328  } else if (auto *op = memrefVal.getDefiningOp()) {
329  op->erase();
330  }
331  }
332 
333  // Modify the function type accordingly
334  for (auto funcOp : funcsToModify) {
335  SmallVector<Type, 4> newArgTypes;
336  for (BlockArgument arg : funcOp.getArguments()) {
337  newArgTypes.push_back(arg.getType());
338  }
339  FunctionType newFuncType =
340  FunctionType::get(funcOp.getContext(), newArgTypes,
341  funcOp.getFunctionType().getResults());
342  funcOp.setType(newFuncType);
343  }
344 
345  return success();
346 }
347 
348 void MemoryBankingPass::runOnOperation() {
349  if (getOperation().isExternal() || bankingFactor == 1)
350  return;
351 
352  if (bankingFactor == 0) {
353  getOperation().emitError("banking factor must be greater than 1");
354  signalPassFailure();
355  return;
356  }
357 
358  getOperation().walk([&](mlir::affine::AffineParallelOp parOp) {
359  DenseSet<Value> memrefsInPar = collectMemRefs(parOp);
360 
361  for (auto memrefVal : memrefsInPar)
362  memoryToBanks[memrefVal] = createBanks(memrefVal, bankingFactor);
363  });
364 
365  auto *ctx = &getContext();
366  RewritePatternSet patterns(ctx);
367 
368  DenseSet<Operation *> processedOps;
369  patterns.add<BankAffineLoadPattern>(ctx, bankingFactor, memoryToBanks);
370  patterns.add<BankAffineStorePattern>(ctx, bankingFactor, memoryToBanks,
371  opsToErase, processedOps);
372  patterns.add<BankReturnPattern>(ctx, memoryToBanks);
373 
374  GreedyRewriteConfig config;
375  config.strictMode = GreedyRewriteStrictness::ExistingOps;
376  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
377  config))) {
378  signalPassFailure();
379  }
380 
381  // Clean up the old memref values
382  DenseSet<Value> oldMemRefVals;
383  for (const auto &[memory, _] : memoryToBanks)
384  oldMemRefVals.insert(memory);
385 
386  if (failed(cleanUpOldMemRefs(oldMemRefVals, opsToErase))) {
387  signalPassFailure();
388  }
389 }
390 
391 namespace circt {
392 std::unique_ptr<mlir::Pass>
393 createMemoryBankingPass(std::optional<unsigned> bankingFactor) {
394  return std::make_unique<MemoryBankingPass>(bankingFactor);
395 }
396 } // namespace circt
assert(baseType &&"element must be base type")
SmallVector< Value, 4 > createBanks(Value originalMem, uint64_t bankingFactor)
DenseSet< Value > collectMemRefs(mlir::affine::AffineParallelOp parOp)
LogicalResult cleanUpOldMemRefs(DenseSet< Value > &oldMemRefVals, DenseSet< Operation * > &opsToErase)
MemRefType computeBankedMemRefType(MemRefType originalType, uint64_t bankingFactor)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
std::unique_ptr< mlir::Pass > createMemoryBankingPass(std::optional< unsigned > bankingFactor=std::nullopt)
DenseMap< Value, SmallVector< Value > > & memoryToBanks
BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor, DenseMap< Value, SmallVector< Value >> &memoryToBanks)
LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp, PatternRewriter &rewriter) const override
DenseMap< Value, SmallVector< Value > > & memoryToBanks
DenseSet< Operation * > & processedOps
BankAffineStorePattern(MLIRContext *context, uint64_t bankingFactor, DenseMap< Value, SmallVector< Value >> &memoryToBanks, DenseSet< Operation * > &opsToErase, DenseSet< Operation * > &processedOps)
DenseSet< Operation * > & opsToErase
LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp, PatternRewriter &rewriter) const override
DenseMap< Value, SmallVector< Value > > & memoryToBanks
LogicalResult matchAndRewrite(func::ReturnOp returnOp, PatternRewriter &rewriter) const override
BankReturnPattern(MLIRContext *context, DenseMap< Value, SmallVector< Value >> &memoryToBanks)