11 #include "mlir/IR/ImplicitLocOpBuilder.h"
12 #include "mlir/Pass/Pass.h"
13 #include "llvm/Support/Debug.h"
15 #define DEBUG_TYPE "arc-allocate-state"
19 #define GEN_PASS_DEF_ALLOCATESTATE
20 #include "circt/Dialect/Arc/ArcPasses.h.inc"
25 using namespace circt;
28 using llvm::SmallMapVector;
// Pass that assigns byte offsets to the state, memory, and sub-storage
// allocation ops found inside a `ModelOp`, grouping them by the storage value
// they allocate from, and rewires users to read through StorageGetOp getters.
35 struct AllocateStatePass
36 :
public arc::impl::AllocateStateBase<AllocateStatePass> {
// Entry point: visits every block nested in the model and allocates it.
37 void runOnOperation()
override;
// Buckets the allocation ops directly inside `block` by storage value and
// lays out each bucket via allocateOps.
38 void allocateBlock(
Block *block);
// Assigns offsets to `ops` (all allocating from `storage`) and redirects
// uses of the allocated values to storage getters.
39 void allocateOps(Value storage,
Block *block, ArrayRef<Operation *> ops);
// Logs which model is being processed, then walks every block nested inside
// the ModelOp and hands each one to allocateBlock. Blocks are handled
// independently; each gets its own layout per storage value.
43 void AllocateStatePass::runOnOperation() {
44 ModelOp modelOp = getOperation();
45 LLVM_DEBUG(
llvm::dbgs() <<
"Allocating state in `" << modelOp.getName()
50 modelOp.walk([&](Block *block) { allocateBlock(block); });
// Collects the allocation ops that sit directly in `block` (AllocStateOp,
// RootInputOp, RootOutputOp, AllocMemoryOp, and possibly further kinds -- the
// full isa<> list is truncated here), buckets them by their first operand
// (the storage value they allocate from), and lays out each bucket with
// allocateOps. Ops in nested regions are not visited here; they are reached
// when runOnOperation's walk calls this for their own block.
53 void AllocateStatePass::allocateBlock(Block *block) {
// MapVector iterates in insertion order, so buckets are processed in the
// order their storage was first seen and the layout is deterministic.
54 SmallMapVector<Value, std::vector<Operation *>, 1> opsByStorage;
58 for (
auto &op : *block) {
59 if (isa<AllocStateOp, RootInputOp, RootOutputOp, AllocMemoryOp,
61 opsByStorage[op.getOperand(0)].push_back(&op);
63 LLVM_DEBUG(
llvm::dbgs() <<
"- Visiting block in "
64 << block->getParentOp()->getName() <<
"\n");
// Each storage value gets an independent layout starting at offset 0.
67 for (
auto &[storage, ops] : opsByStorage)
68 allocateOps(storage, block, ops);
// Lays out `ops` -- all of which allocate from `storage` -- as consecutive
// byte ranges starting at offset 0. Each op's offset (and, for memories, its
// stride) is recorded as an attribute. Every use of an allocated value is then
// redirected to a StorageGetOp that reads it from `storage` at that offset.
// When the op owning `storage` is a proper ancestor of `block`'s parent op,
// an AllocStorageOp sub-storage is created and uses of `storage` in the
// allocations and getters are rewired to it.
71 void AllocateStatePass::allocateOps(Value storage, Block *block,
72 ArrayRef<Operation *> ops) {
// (allocated value, storage it is read from, byte offset) per allocation;
// getters are materialized only after every offset has been assigned.
73 SmallVector<std::tuple<Value, Value, IntegerAttr>> gettersToCreate;
// Bump allocator over the storage. Before reserving `numBytes`, the cursor is
// aligned to bit_ceil(min(numBytes, 16)): the size is rounded up to a power
// of two (alignToPowerOf2 requires one) and alignment is capped at 16 bytes.
// NOTE(review): the lambda's return is not visible here; presumably it
// returns `offset` -- TODO confirm.
77 unsigned currentByte = 0;
78 auto allocBytes = [&](
unsigned numBytes) {
79 currentByte = llvm::alignToPowerOf2(
80 currentByte, llvm::bit_ceil(std::min(numBytes, 16U)));
81 unsigned offset = currentByte;
82 currentByte += numBytes;
// Insertion point is just before the block's parent op; this builder makes
// the attributes below and the sub-storage op created at the end.
87 OpBuilder
builder(block->getParentOp());
88 for (
auto *op : ops) {
// State and root I/O ops: size is the byte width of their StateType.
// NOTE(review): member-style `.cast<>()` on Type is deprecated in recent
// MLIR; `cast<StateType>(result.getType())` is the current spelling.
89 if (isa<AllocStateOp, RootInputOp, RootOutputOp>(op)) {
90 auto result = op->getResult(0);
91 auto storage = op->getOperand(0);
92 unsigned numBytes = result.getType().cast<StateType>().getByteWidth();
93 auto offset =
builder.getI32IntegerAttr(allocBytes(numBytes));
94 op->setAttr(
"offset", offset);
95 gettersToCreate.emplace_back(result, storage, offset);
// Memories occupy numWords * stride bytes; the stride is recorded as well.
99 if (
auto memOp = dyn_cast<AllocMemoryOp>(op)) {
100 auto memType = memOp.getType();
101 unsigned stride = memType.getStride();
102 unsigned numBytes = memType.getNumWords() * stride;
103 auto offset =
builder.getI32IntegerAttr(allocBytes(numBytes));
104 op->setAttr(
"offset", offset);
105 op->setAttr(
"stride",
builder.getI32IntegerAttr(stride));
106 gettersToCreate.emplace_back(memOp, memOp.getStorage(), offset);
// Nested storage: reserves getType().getSize() bytes inside its input
// (parent) storage.
110 if (
auto allocStorageOp = dyn_cast<AllocStorageOp>(op)) {
111 auto offset =
builder.getI32IntegerAttr(
112 allocBytes(allocStorageOp.getType().getSize()));
113 allocStorageOp.setOffsetAttr(offset);
114 gettersToCreate.emplace_back(allocStorageOp, allocStorageOp.getInput(),
// Any other op kind is unexpected. NOTE(review): under NDEBUG this assert
// compiles away and the op is silently left unallocated; `llvm_unreachable`
// is the conventional LLVM idiom for this.
119 assert(
"unsupported op for allocation" &&
false);
// Walk-order index of every op nested in `block` (Operation::walk defaults to
// post-order, which keeps textual order among sibling ops). Used to decide
// whether an existing shared getter currently sits after a new user.
125 DenseMap<Operation *, unsigned> opOrder;
126 block->walk([&](Operation *op) { opOrder.insert({op, opOrder.size()}); });
127 SmallVector<StorageGetOp> getters;
128 for (
auto [result, storage, offset] : gettersToCreate) {
// make_early_inc_range: replaceUsesOfWith below removes `user` from the use
// list being iterated.
130 for (
auto *user : llvm::make_early_inc_range(result.getUsers())) {
// `getterForBlock` (declared on a line not shown) caches one getter per
// user block.
131 auto &getter = getterForBlock[user->getBlock()];
134 auto userOrder = opOrder.lookup(user);
// Non-AllocStorageOp values get a fresh getter right before every user;
// AllocStorageOp results share a single getter per block.
135 if (!getter || !result.getDefiningOp<AllocStorageOp>()) {
136 ImplicitLocOpBuilder
builder(result.getLoc(), user);
138 builder.create<StorageGetOp>(result.getType(), storage, offset);
139 getters.push_back(getter);
140 opOrder[getter] = userOrder;
141 }
// A shared getter that currently follows this user is hoisted above it so it
// precedes the earliest user seen so far.
else if (userOrder < opOrder.lookup(getter)) {
142 getter->moveBefore(user);
143 opOrder[getter] = userOrder;
145 user->replaceUsesOfWith(result, getter);
// Op owning `storage`: its defining op or, for a block argument, the op whose
// region contains the argument's block.
150 Operation *storageOwner = storage.getDefiningOp();
152 storageOwner = storage.cast<BlockArgument>().getOwner()->getParentOp();
// Storage owned by an enclosing op: carve out a sub-storage just before the
// block's parent op and point the allocation ops and getters at it instead.
154 if (storageOwner->isProperAncestor(block->getParentOp())) {
155 auto substorage =
builder.create<AllocStorageOp>(
156 block->getParentOp()->getLoc(),
159 op->replaceUsesOfWith(storage, substorage);
160 for (
auto op : getters)
161 op->replaceUsesOfWith(storage, substorage);
168 return std::make_unique<AllocateStatePass>();
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createAllocateStatePass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & dbgs()