CIRCT 23.0.0git
Loading...
Searching...
No Matches
MemoryAllocationPass.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Dialect/Index/IR/IndexOps.h"
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
15
16namespace circt {
17namespace rtg {
18#define GEN_PASS_DEF_MEMORYALLOCATIONPASS
19#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
20} // namespace rtg
21} // namespace circt
22
23using namespace mlir;
24using namespace circt;
25using namespace circt::rtg;
26
27//===----------------------------------------------------------------------===//
28// Memory Allocation Pass
29//===----------------------------------------------------------------------===//
30
31namespace {
32struct AllocationInfo {
33 APInt nextFree;
34 APInt maxAddr;
35};
36
37/// Helper function to adjust APInt width and check for truncation errors.
38LogicalResult adjustAPIntWidth(APInt &value, unsigned targetBitWidth,
39 Location loc) {
40 if (value.getBitWidth() > targetBitWidth && !value.isIntN(targetBitWidth))
41 return mlir::emitError(
42 loc, "cannot truncate APInt because value is too big to fit");
43
44 if (value.getBitWidth() < targetBitWidth) {
45 value = value.zext(targetBitWidth);
46 return success();
47 }
48
49 value = value.trunc(targetBitWidth);
50 return success();
51}
52
53struct MemoryAllocationPass
54 : public rtg::impl::MemoryAllocationPassBase<MemoryAllocationPass> {
55 using Base::Base;
56
57 void runOnOperation() override;
58};
59} // namespace
60
61void MemoryAllocationPass::runOnOperation() {
62 auto testOp = getOperation();
63 DenseMap<Value, AllocationInfo> nextFreeMap;
64
65 if (!useImmediates) {
66 testOp->emitError("label mode not yet supported");
67 return signalPassFailure();
68 }
69
70 // Collect memory block declarations in target.
71 auto target = testOp.getTargetAttr();
72 if (!target)
73 return;
74
75 SymbolTable table(testOp->getParentOfType<ModuleOp>());
76 auto targetOp = table.lookupNearestSymbolFrom<TargetOp>(testOp, target);
77
78 for (auto &op : *targetOp.getBody()) {
79 auto memBlock = dyn_cast<MemoryBlockDeclareOp>(&op);
80 if (!memBlock)
81 continue;
82
83 auto &slot = nextFreeMap[memBlock.getResult()];
84 slot.nextFree = memBlock.getBaseAddress();
85 slot.maxAddr = memBlock.getEndAddress();
86 }
87
88 // Propagate memory block declarations from target to test.
89 auto targetYields = targetOp.getBody()->getTerminator()->getOperands();
90 auto targetEntries = targetOp.getTarget().getEntries();
91 auto testEntries = testOp.getTargetType().getEntries();
92 auto testArgs = testOp.getBody()->getArguments();
93
94 size_t targetIdx = 0;
95 for (auto [testEntry, testArg] : llvm::zip(testEntries, testArgs)) {
96 while (targetIdx < targetEntries.size() &&
97 targetEntries[targetIdx].name.getValue() < testEntry.name.getValue())
98 targetIdx++;
99
100 if (targetIdx < targetEntries.size() &&
101 targetEntries[targetIdx].name.getValue() == testEntry.name.getValue()) {
102 auto targetYield = targetYields[targetIdx];
103 auto it = nextFreeMap.find(targetYield);
104 if (it != nextFreeMap.end())
105 nextFreeMap[testArg] = it->second;
106 }
107 }
108
109 // Iterate through the test and allocate memory for each 'memory_alloc'
110 // operation.
111 for (auto &op : llvm::make_early_inc_range(*testOp.getBody())) {
112 auto mem = dyn_cast<MemoryAllocOp>(&op);
113 if (!mem)
114 continue;
115
116 auto iter = nextFreeMap.find(mem.getMemoryBlock());
117 if (iter == nextFreeMap.end()) {
118 mem->emitError("memory block not found");
119 return signalPassFailure();
120 }
121
122 APInt size;
123 if (!matchPattern(mem.getSize(), m_ConstantInt(&size))) {
124 mem->emitError("could not determine memory allocation size");
125 return signalPassFailure();
126 }
127
128 APInt alignment;
129 if (!matchPattern(mem.getAlignment(), m_ConstantInt(&alignment))) {
130 mem->emitError("could not determine memory allocation alignment");
131 return signalPassFailure();
132 }
133
134 if (size.isZero()) {
135 mem->emitError(
136 "memory allocation size must be greater than zero (was 0)");
137 return signalPassFailure();
138 }
139
140 if (!alignment.isPowerOf2()) {
141 mem->emitError("memory allocation alignment must be a power of two (was ")
142 << alignment.getZExtValue() << ")";
143 return signalPassFailure();
144 }
145
146 auto &memBlock = iter->getSecond();
147 APInt nextFree = memBlock.nextFree;
148 unsigned bitWidth = nextFree.getBitWidth();
149
150 if (failed(adjustAPIntWidth(size, bitWidth, mem.getLoc())) ||
151 failed(adjustAPIntWidth(alignment, bitWidth, mem.getLoc())))
152 return signalPassFailure();
153
154 // Calculate aligned address
155 APInt bias(bitWidth, !nextFree.isZero());
156 APInt ceilDiv = (nextFree - bias).udiv(alignment) + bias;
157 APInt nextFreeAligned = ceilDiv * alignment;
158
159 memBlock.nextFree = nextFreeAligned + size;
160 if (memBlock.nextFree.ugt(memBlock.maxAddr)) {
161 mem->emitError("memory block not large enough to fit all allocations");
162 return signalPassFailure();
163 }
164
165 ++numMemoriesAllocated;
166
167 IRRewriter builder(mem);
168 builder.replaceOpWithNewOp<ConstantOp>(
169 mem, ImmediateAttr::get(builder.getContext(), nextFreeAligned));
170 }
171}
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition rtg.py:1