12#include "mlir/Dialect/Index/IR/IndexOps.h"
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
18#define GEN_PASS_DEF_MEMORYALLOCATIONPASS
19#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
32struct AllocationInfo {
38LogicalResult adjustAPIntWidth(APInt &value,
unsigned targetBitWidth,
40 if (value.getBitWidth() > targetBitWidth && !value.isIntN(targetBitWidth))
41 return mlir::emitError(
42 loc,
"cannot truncate APInt because value is too big to fit");
44 if (value.getBitWidth() < targetBitWidth) {
45 value = value.zext(targetBitWidth);
49 value = value.trunc(targetBitWidth);
53struct MemoryAllocationPass
54 :
public rtg::impl::MemoryAllocationPassBase<MemoryAllocationPass> {
57 void runOnOperation()
override;
61void MemoryAllocationPass::runOnOperation() {
62 auto testOp = getOperation();
63 DenseMap<Value, AllocationInfo> nextFreeMap;
66 testOp->emitError(
"label mode not yet supported");
67 return signalPassFailure();
71 auto target = testOp.getTargetAttr();
75 SymbolTable table(testOp->getParentOfType<ModuleOp>());
76 auto targetOp = table.lookupNearestSymbolFrom<TargetOp>(testOp, target);
78 for (
auto &op : *targetOp.getBody()) {
79 auto memBlock = dyn_cast<MemoryBlockDeclareOp>(&op);
83 auto &slot = nextFreeMap[memBlock.getResult()];
84 slot.nextFree = memBlock.getBaseAddress();
85 slot.maxAddr = memBlock.getEndAddress();
89 auto targetYields = targetOp.getBody()->getTerminator()->getOperands();
90 auto targetEntries = targetOp.getTarget().getEntries();
91 auto testEntries = testOp.getTargetType().getEntries();
92 auto testArgs = testOp.getBody()->getArguments();
95 for (
auto [testEntry, testArg] :
llvm::zip(testEntries, testArgs)) {
96 while (targetIdx < targetEntries.size() &&
97 targetEntries[targetIdx].name.getValue() < testEntry.name.getValue())
100 if (targetIdx < targetEntries.size() &&
101 targetEntries[targetIdx].name.getValue() == testEntry.name.getValue()) {
102 auto targetYield = targetYields[targetIdx];
103 auto it = nextFreeMap.find(targetYield);
104 if (it != nextFreeMap.end())
105 nextFreeMap[testArg] = it->second;
111 for (
auto &op :
llvm::make_early_inc_range(*testOp.getBody())) {
112 auto mem = dyn_cast<MemoryAllocOp>(&op);
116 auto iter = nextFreeMap.find(mem.getMemoryBlock());
117 if (iter == nextFreeMap.end()) {
118 mem->emitError(
"memory block not found");
119 return signalPassFailure();
123 if (!matchPattern(mem.getSize(), m_ConstantInt(&size))) {
124 mem->emitError(
"could not determine memory allocation size");
125 return signalPassFailure();
129 if (!matchPattern(mem.getAlignment(), m_ConstantInt(&alignment))) {
130 mem->emitError(
"could not determine memory allocation alignment");
131 return signalPassFailure();
136 "memory allocation size must be greater than zero (was 0)");
137 return signalPassFailure();
140 if (!alignment.isPowerOf2()) {
141 mem->emitError(
"memory allocation alignment must be a power of two (was ")
142 << alignment.getZExtValue() <<
")";
143 return signalPassFailure();
146 auto &memBlock = iter->getSecond();
147 APInt nextFree = memBlock.nextFree;
148 unsigned bitWidth = nextFree.getBitWidth();
150 if (failed(adjustAPIntWidth(size, bitWidth, mem.getLoc())) ||
151 failed(adjustAPIntWidth(alignment, bitWidth, mem.getLoc())))
152 return signalPassFailure();
155 APInt bias(bitWidth, !nextFree.isZero());
156 APInt ceilDiv = (nextFree - bias).udiv(alignment) + bias;
157 APInt nextFreeAligned = ceilDiv * alignment;
159 memBlock.nextFree = nextFreeAligned + size;
160 if (memBlock.nextFree.ugt(memBlock.maxAddr)) {
161 mem->emitError(
"memory block not large enough to fit all allocations");
162 return signalPassFailure();
165 ++numMemoriesAllocated;
167 IRRewriter builder(mem);
168 builder.replaceOpWithNewOp<ConstantOp>(
169 mem, ImmediateAttr::get(builder.getContext(), nextFreeAligned));
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.