CIRCT 20.0.0git
Loading...
Searching...
No Matches
AllocateState.cpp
Go to the documentation of this file.
//===- AllocateState.cpp --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/ImplicitLocOpBuilder.h"
12#include "mlir/Pass/Pass.h"
13#include "llvm/Support/Debug.h"
14
15#define DEBUG_TYPE "arc-allocate-state"
16
17namespace circt {
18namespace arc {
19#define GEN_PASS_DEF_ALLOCATESTATE
20#include "circt/Dialect/Arc/ArcPasses.h.inc"
21} // namespace arc
22} // namespace circt
23
24using namespace mlir;
25using namespace circt;
26using namespace arc;
27
28using llvm::SmallMapVector;
29
//===----------------------------------------------------------------------===//
// Pass Implementation
//===----------------------------------------------------------------------===//
33
34namespace {
35struct AllocateStatePass
36 : public arc::impl::AllocateStateBase<AllocateStatePass> {
37 void runOnOperation() override;
38 void allocateBlock(Block *block);
39 void allocateOps(Value storage, Block *block, ArrayRef<Operation *> ops);
40};
41} // namespace
42
43void AllocateStatePass::runOnOperation() {
44 ModelOp modelOp = getOperation();
45 LLVM_DEBUG(llvm::dbgs() << "Allocating state in `" << modelOp.getName()
46 << "`\n");
47
48 // Walk the blocks from innermost to outermost and group all state allocations
49 // in that block in one larger allocation.
50 modelOp.walk([&](Block *block) { allocateBlock(block); });
51}
52
53void AllocateStatePass::allocateBlock(Block *block) {
54 SmallMapVector<Value, std::vector<Operation *>, 1> opsByStorage;
55
56 // Group operations by their storage. There is generally just one storage,
57 // passed into the model as a block argument.
58 for (auto &op : *block) {
59 if (isa<AllocStateOp, RootInputOp, RootOutputOp, AllocMemoryOp,
60 AllocStorageOp>(&op))
61 opsByStorage[op.getOperand(0)].push_back(&op);
62 }
63 LLVM_DEBUG(llvm::dbgs() << "- Visiting block in "
64 << block->getParentOp()->getName() << "\n");
65
66 // Actually allocate each operation.
67 for (auto &[storage, ops] : opsByStorage)
68 allocateOps(storage, block, ops);
69}
70
71void AllocateStatePass::allocateOps(Value storage, Block *block,
72 ArrayRef<Operation *> ops) {
73 SmallVector<std::tuple<Value, Value, IntegerAttr>> gettersToCreate;
74
75 // Helper function to allocate storage aligned to its own size, or 8 bytes at
76 // most.
77 unsigned currentByte = 0;
78 auto allocBytes = [&](unsigned numBytes) {
79 currentByte = llvm::alignToPowerOf2(
80 currentByte, llvm::bit_ceil(std::min(numBytes, 16U)));
81 unsigned offset = currentByte;
82 currentByte += numBytes;
83 return offset;
84 };
85
86 // Allocate storage for the operations.
87 OpBuilder builder(block->getParentOp());
88 for (auto *op : ops) {
89 if (isa<AllocStateOp, RootInputOp, RootOutputOp>(op)) {
90 auto result = op->getResult(0);
91 auto storage = op->getOperand(0);
92 unsigned numBytes = cast<StateType>(result.getType()).getByteWidth();
93 auto offset = builder.getI32IntegerAttr(allocBytes(numBytes));
94 op->setAttr("offset", offset);
95 gettersToCreate.emplace_back(result, storage, offset);
96 continue;
97 }
98
99 if (auto memOp = dyn_cast<AllocMemoryOp>(op)) {
100 auto memType = memOp.getType();
101 unsigned stride = memType.getStride();
102 unsigned numBytes = memType.getNumWords() * stride;
103 auto offset = builder.getI32IntegerAttr(allocBytes(numBytes));
104 op->setAttr("offset", offset);
105 op->setAttr("stride", builder.getI32IntegerAttr(stride));
106 gettersToCreate.emplace_back(memOp, memOp.getStorage(), offset);
107 continue;
108 }
109
110 if (auto allocStorageOp = dyn_cast<AllocStorageOp>(op)) {
111 auto offset = builder.getI32IntegerAttr(
112 allocBytes(allocStorageOp.getType().getSize()));
113 allocStorageOp.setOffsetAttr(offset);
114 gettersToCreate.emplace_back(allocStorageOp, allocStorageOp.getInput(),
115 offset);
116 continue;
117 }
118
119 assert("unsupported op for allocation" && false);
120 }
121
122 // For every user of the alloc op, create a local `StorageGetOp`.
123 // First, create an ordering of operations to avoid a very expensive
124 // combination of isBeforeInBlock and moveBefore calls (which can be O(n²))
125 DenseMap<Operation *, unsigned> opOrder;
126 block->walk([&](Operation *op) { opOrder.insert({op, opOrder.size()}); });
127 SmallVector<StorageGetOp> getters;
128 for (auto [result, storage, offset] : gettersToCreate) {
130 for (auto *user : llvm::make_early_inc_range(result.getUsers())) {
131 auto &getter = getterForBlock[user->getBlock()];
132 // Create a local getter in front of each user, except for
133 // `AllocStorageOp`s, for which we create a block-wider accessor.
134 auto userOrder = opOrder.lookup(user);
135 if (!getter || !result.getDefiningOp<AllocStorageOp>()) {
136 ImplicitLocOpBuilder builder(result.getLoc(), user);
137 getter =
138 builder.create<StorageGetOp>(result.getType(), storage, offset);
139 getters.push_back(getter);
140 opOrder[getter] = userOrder;
141 } else if (userOrder < opOrder.lookup(getter)) {
142 getter->moveBefore(user);
143 opOrder[getter] = userOrder;
144 }
145 user->replaceUsesOfWith(result, getter);
146 }
147 }
148
149 // Create the substorage accessor at the beginning of the block.
150 Operation *storageOwner = storage.getDefiningOp();
151 if (!storageOwner)
152 storageOwner = cast<BlockArgument>(storage).getOwner()->getParentOp();
153
154 if (storageOwner->isProperAncestor(block->getParentOp())) {
155 auto substorage = builder.create<AllocStorageOp>(
156 block->getParentOp()->getLoc(),
157 StorageType::get(&getContext(), currentByte), storage);
158 for (auto *op : ops)
159 op->replaceUsesOfWith(storage, substorage);
160 for (auto op : getters)
161 op->replaceUsesOfWith(storage, substorage);
162 } else {
163 storage.setType(StorageType::get(&getContext(), currentByte));
164 }
165}
166
167std::unique_ptr<Pass> arc::createAllocateStatePass() {
168 return std::make_unique<AllocateStatePass>();
169}
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createAllocateStatePass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.