12#include "mlir/Dialect/Index/IR/IndexOps.h"
13#include "mlir/IR/PatternMatch.h"
17#define GEN_PASS_DEF_MEMORYALLOCATIONPASS
18#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
// Per-memory-block allocation state used by the pass below. `nextFree` is the
// next address available for allocation (seeded with the block's base
// address). `maxAddr` is the block's end address, which the post-allocation
// `nextFree` must not exceed (unsigned comparison).
31struct AllocationInfo {
// Resizes `value` in place to `targetBitWidth` bits so that it can be combined
// with other APInts of that width.
//  - If `value` is wider than `targetBitWidth` and does not fit as an unsigned
//    `targetBitWidth`-bit integer (`APInt::isIntN`), an error is emitted at
//    `loc` and failure is returned; `value` is left unmodified.
//  - If `value` is narrower, it is zero-extended.
//  - If it is wider (and fits), it is truncated. The check above guarantees
//    that no set bits are dropped.
37LogicalResult adjustAPIntWidth(APInt &value,
unsigned targetBitWidth,
39 if (value.getBitWidth() > targetBitWidth && !value.isIntN(targetBitWidth))
40 return mlir::emitError(
41 loc,
"cannot truncate APInt because value is too big to fit");
43 if (value.getBitWidth() < targetBitWidth) {
44 value = value.zext(targetBitWidth);
48 value = value.trunc(targetBitWidth);
// Pass that replaces each `MemoryAllocOp` in a test with a constant address.
// The address is carved out of the memory blocks declared by the test's
// target (see `runOnOperation`).
52struct MemoryAllocationPass
53 :
public rtg::impl::MemoryAllocationPassBase<MemoryAllocationPass> {
56 void runOnOperation()
override;
// Assigns a concrete address to every `MemoryAllocOp` in the test body.
//
// 1. Every `MemoryBlockDeclareOp` in the referenced target seeds an
//    `AllocationInfo`. Its `nextFree` starts at the block's base address and
//    its `maxAddr` is the block's end address.
// 2. A test block argument inherits a copy of that state when its target-type
//    entry name matches a target entry whose yielded value has state.
// 3. Each allocation is placed at the block's `nextFree` rounded up to the
//    requested alignment. `nextFree` is then advanced past it, and the op is
//    replaced by a `ConstantOp` holding the chosen address.
// Every failure emits a diagnostic on the offending op and signals pass
// failure.
60void MemoryAllocationPass::runOnOperation() {
61 auto testOp = getOperation();
  // Allocation state keyed by the SSA value denoting a memory block: the
  // target's declare results first, then the test's block arguments.
62 DenseMap<Value, AllocationInfo> nextFreeMap;
  // Label mode is not implemented yet; reject it.
65 testOp->emitError(
"label mode not yet supported");
66 return signalPassFailure();
  // Resolve the test's target through the enclosing module's symbol table.
70 auto target = testOp.getTargetAttr();
74 SymbolTable table(testOp->getParentOfType<ModuleOp>());
75 auto targetOp = table.lookupNearestSymbolFrom<TargetOp>(testOp, target);
  // Seed allocation state for each memory block declared in the target body.
77 for (
auto &op : *targetOp.getBody()) {
78 auto memBlock = dyn_cast<MemoryBlockDeclareOp>(&op);
82 auto &slot = nextFreeMap[memBlock.getResult()];
83 slot.nextFree = memBlock.getBaseAddress();
84 slot.maxAddr = memBlock.getEndAddress();
  // Map the test's block arguments onto the target's yielded values.
  // `targetYields` is indexed in parallel with `targetEntries`, and `testArgs`
  // with `testEntries`. The two entry lists are walked like a merge join,
  // which presumably requires both to be sorted by name (TODO confirm).
88 auto targetYields = targetOp.getBody()->getTerminator()->getOperands();
89 auto targetEntries = targetOp.getTarget().getEntries();
90 auto testEntries = testOp.getTargetType().getEntries();
91 auto testArgs = testOp.getBody()->getArguments();
94 for (
auto [testEntry, testArg] :
llvm::zip(testEntries, testArgs)) {
  // Skip target entries whose names order before the current test entry.
95 while (targetIdx < targetEntries.size() &&
96 targetEntries[targetIdx].name.getValue() < testEntry.name.getValue())
99 if (targetIdx < targetEntries.size() &&
100 targetEntries[targetIdx].name.getValue() == testEntry.name.getValue()) {
101 auto targetYield = targetYields[targetIdx];
102 auto it = nextFreeMap.find(targetYield);
103 if (it != nextFreeMap.end())
  // The test argument gets its own copy of the state. Allocations made
  // through it therefore do not write back to the target's entry.
  // NOTE(review): `nextFreeMap[testArg]` inserts a new key and may grow the
  // DenseMap, which invalidates `it` and the reference `it->second`. In
  // either evaluation order of this assignment, the copy can then read freed
  // bucket storage. Copy first instead, e.g.
  // `AllocationInfo info = it->second; nextFreeMap[testArg] = info;`.
104 nextFreeMap[testArg] = it->second;
  // Allocate every `MemoryAllocOp` in the test body. `make_early_inc_range`
  // keeps the walk valid while ops are replaced (and erased) below.
110 for (
auto &op :
llvm::make_early_inc_range(*testOp.getBody())) {
111 auto mem = dyn_cast<MemoryAllocOp>(&op);
  // The allocation must refer to a memory block with known state.
115 auto iter = nextFreeMap.find(mem.getMemoryBlock());
116 if (iter == nextFreeMap.end()) {
117 mem->emitError(
"memory block not found");
118 return signalPassFailure();
  // Size and alignment must both be produced by `index::ConstantOp`s.
121 auto sizeOp = mem.getSize().getDefiningOp<index::ConstantOp>();
123 mem->emitError(
"could not determine memory allocation size");
124 return signalPassFailure();
127 auto alignOp = mem.getAlignment().getDefiningOp<index::ConstantOp>();
129 mem->emitError(
"could not determine memory allocation alignment");
130 return signalPassFailure();
133 APInt size = sizeOp.getValue();
134 APInt alignment = alignOp.getValue();
138 "memory allocation size must be greater than zero (was 0)");
139 return signalPassFailure();
  // `isPowerOf2` is false for zero, so this check also prevents a division
  // by zero in the `udiv` below.
142 if (!alignment.isPowerOf2()) {
143 mem->emitError(
"memory allocation alignment must be a power of two (was ")
144 << alignment.getZExtValue() <<
")";
145 return signalPassFailure();
  // Take the block's state by reference so the update to `nextFree` below is
  // visible to later allocations from the same block.
148 auto &memBlock = iter->getSecond();
149 APInt nextFree = memBlock.nextFree;
  // Bring size and alignment to the block's address bit width.
150 unsigned bitWidth = nextFree.getBitWidth();
152 if (failed(adjustAPIntWidth(size, bitWidth, mem.getLoc())) ||
153 failed(adjustAPIntWidth(alignment, bitWidth, mem.getLoc())))
154 return signalPassFailure();
  // Round `nextFree` up to a multiple of `alignment` using
  //   ceil(n / a) = (n - 1) / a + 1  for n != 0, and 0 for n == 0.
  // `bias` is 1 when `nextFree` is non-zero and 0 otherwise. This quotient
  // cannot overflow, unlike the `(n + a - 1) / a` form.
157 APInt bias(bitWidth, !nextFree.isZero());
158 APInt ceilDiv = (nextFree - bias).udiv(alignment) + bias;
159 APInt nextFreeAligned = ceilDiv * alignment;
  // Advance past the allocation and check that it stays within the block.
  // NOTE(review): APInt arithmetic wraps modulo 2^bitWidth. An overflow in
  // `ceilDiv * alignment` (above) or in `nextFreeAligned + size` therefore
  // produces a small value that can pass the `ugt(maxAddr)` check. Consider
  // `umul_ov`/`uadd_ov` and reporting an error when they overflow.
161 memBlock.nextFree = nextFreeAligned + size;
162 if (memBlock.nextFree.ugt(memBlock.maxAddr)) {
163 mem->emitError(
"memory block not large enough to fit all allocations");
164 return signalPassFailure();
  // Count the allocation. The counter is declared elsewhere; presumably it is
  // a pass statistic.
167 ++numMemoriesAllocated;
  // Replace the allocation op with a constant carrying its start address.
169 IRRewriter builder(mem);
170 builder.replaceOpWithNewOp<ConstantOp>(
171 mem, ImmediateAttr::get(builder.getContext(), nextFreeAligned));
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.