Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
MemoryAllocationPass.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Dialect/Index/IR/IndexOps.h"
13#include "mlir/IR/PatternMatch.h"
14
15namespace circt {
16namespace rtg {
17#define GEN_PASS_DEF_MEMORYALLOCATIONPASS
18#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
19} // namespace rtg
20} // namespace circt
21
22using namespace mlir;
23using namespace circt;
24using namespace circt::rtg;
25
26//===----------------------------------------------------------------------===//
27// Memory Allocation Pass
28//===----------------------------------------------------------------------===//
29
30namespace {
31struct AllocationInfo {
32 APInt nextFree;
33 APInt maxAddr;
34};
35
36/// Helper function to adjust APInt width and check for truncation errors.
37LogicalResult adjustAPIntWidth(APInt &value, unsigned targetBitWidth,
38 Location loc) {
39 if (value.getBitWidth() > targetBitWidth && !value.isIntN(targetBitWidth))
40 return mlir::emitError(
41 loc, "cannot truncate APInt because value is too big to fit");
42
43 if (value.getBitWidth() < targetBitWidth) {
44 value = value.zext(targetBitWidth);
45 return success();
46 }
47
48 value = value.trunc(targetBitWidth);
49 return success();
50}
51
52struct MemoryAllocationPass
53 : public rtg::impl::MemoryAllocationPassBase<MemoryAllocationPass> {
54 using Base::Base;
55
56 void runOnOperation() override;
57};
58} // namespace
59
60void MemoryAllocationPass::runOnOperation() {
61 auto testOp = getOperation();
62 DenseMap<Value, AllocationInfo> nextFreeMap;
63
64 if (!useImmediates) {
65 testOp->emitError("label mode not yet supported");
66 return signalPassFailure();
67 }
68
69 // Collect memory block declarations in target.
70 auto target = testOp.getTargetAttr();
71 if (!target)
72 return;
73
74 SymbolTable table(testOp->getParentOfType<ModuleOp>());
75 auto targetOp = table.lookupNearestSymbolFrom<TargetOp>(testOp, target);
76
77 for (auto &op : *targetOp.getBody()) {
78 auto memBlock = dyn_cast<MemoryBlockDeclareOp>(&op);
79 if (!memBlock)
80 continue;
81
82 auto &slot = nextFreeMap[memBlock.getResult()];
83 slot.nextFree = memBlock.getBaseAddress();
84 slot.maxAddr = memBlock.getEndAddress();
85 }
86
87 // Propagate memory block declarations from target to test.
88 auto targetYields = targetOp.getBody()->getTerminator()->getOperands();
89 auto targetEntries = targetOp.getTarget().getEntries();
90 auto testEntries = testOp.getTargetType().getEntries();
91 auto testArgs = testOp.getBody()->getArguments();
92
93 size_t targetIdx = 0;
94 for (auto [testEntry, testArg] : llvm::zip(testEntries, testArgs)) {
95 while (targetIdx < targetEntries.size() &&
96 targetEntries[targetIdx].name.getValue() < testEntry.name.getValue())
97 targetIdx++;
98
99 if (targetIdx < targetEntries.size() &&
100 targetEntries[targetIdx].name.getValue() == testEntry.name.getValue()) {
101 auto targetYield = targetYields[targetIdx];
102 auto it = nextFreeMap.find(targetYield);
103 if (it != nextFreeMap.end())
104 nextFreeMap[testArg] = it->second;
105 }
106 }
107
108 // Iterate through the test and allocate memory for each 'memory_alloc'
109 // operation.
110 for (auto &op : llvm::make_early_inc_range(*testOp.getBody())) {
111 auto mem = dyn_cast<MemoryAllocOp>(&op);
112 if (!mem)
113 continue;
114
115 auto iter = nextFreeMap.find(mem.getMemoryBlock());
116 if (iter == nextFreeMap.end()) {
117 mem->emitError("memory block not found");
118 return signalPassFailure();
119 }
120
121 auto sizeOp = mem.getSize().getDefiningOp<index::ConstantOp>();
122 if (!sizeOp) {
123 mem->emitError("could not determine memory allocation size");
124 return signalPassFailure();
125 }
126
127 auto alignOp = mem.getAlignment().getDefiningOp<index::ConstantOp>();
128 if (!alignOp) {
129 mem->emitError("could not determine memory allocation alignment");
130 return signalPassFailure();
131 }
132
133 APInt size = sizeOp.getValue();
134 APInt alignment = alignOp.getValue();
135
136 if (size.isZero()) {
137 mem->emitError(
138 "memory allocation size must be greater than zero (was 0)");
139 return signalPassFailure();
140 }
141
142 if (!alignment.isPowerOf2()) {
143 mem->emitError("memory allocation alignment must be a power of two (was ")
144 << alignment.getZExtValue() << ")";
145 return signalPassFailure();
146 }
147
148 auto &memBlock = iter->getSecond();
149 APInt nextFree = memBlock.nextFree;
150 unsigned bitWidth = nextFree.getBitWidth();
151
152 if (failed(adjustAPIntWidth(size, bitWidth, mem.getLoc())) ||
153 failed(adjustAPIntWidth(alignment, bitWidth, mem.getLoc())))
154 return signalPassFailure();
155
156 // Calculate aligned address
157 APInt bias(bitWidth, !nextFree.isZero());
158 APInt ceilDiv = (nextFree - bias).udiv(alignment) + bias;
159 APInt nextFreeAligned = ceilDiv * alignment;
160
161 memBlock.nextFree = nextFreeAligned + size;
162 if (memBlock.nextFree.ugt(memBlock.maxAddr)) {
163 mem->emitError("memory block not large enough to fit all allocations");
164 return signalPassFailure();
165 }
166
167 ++numMemoriesAllocated;
168
169 IRRewriter builder(mem);
170 builder.replaceOpWithNewOp<ConstantOp>(
171 mem, ImmediateAttr::get(builder.getContext(), nextFreeAligned));
172 }
173}
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition rtg.py:1