11 #include "mlir/Dialect/SCF/IR/SCF.h"
12 #include "mlir/IR/Dominance.h"
13 #include "mlir/IR/ImplicitLocOpBuilder.h"
14 #include "llvm/ADT/PointerIntPair.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "arc-legalize-state-update"
22 #define GEN_PASS_DEF_LEGALIZESTATEUPDATE
23 #include "circt/Dialect/Arc/ArcPasses.h.inc"
28 using namespace circt;
33 if (isa<StateReadOp, StateWriteOp, CallOpInterface, CallableOpInterface>(op))
35 if (op->getNumRegions() > 0)
46 enum class AccessType { Read = 0, Write = 1 };
49 using Access = llvm::PointerIntPair<Value, 1, AccessType>;
55 struct BlockAccesses {
56 BlockAccesses(Block *block) : block(block) {}
61 OpAccesses *parent =
nullptr;
63 SmallPtrSet<Access, 1> argAccesses;
66 SmallPtrSet<Access, 1> aboveAccesses;
71 OpAccesses(Operation *op) : op(op) {}
76 BlockAccesses *parent =
nullptr;
78 SmallPtrSet<OpAccesses *, 1> callers;
80 SmallPtrSet<Access, 1> accesses;
86 struct AccessAnalysis {
87 LogicalResult analyze(Operation *op);
88 OpAccesses *lookup(Operation *op);
89 BlockAccesses *lookup(Block *block);
94 DenseMap<Value, unsigned> stateOrder;
97 SymbolTableCollection symbolTable;
100 llvm::SpecificBumpPtrAllocator<OpAccesses> opAlloc;
101 llvm::SpecificBumpPtrAllocator<BlockAccesses> blockAlloc;
103 DenseMap<Operation *, OpAccesses *> opAccesses;
104 DenseMap<Block *, BlockAccesses *> blockAccesses;
106 SetVector<OpAccesses *> opWorklist;
107 bool anyInvalidStateAccesses =
false;
110 OpAccesses &
get(Operation *op) {
111 auto &slot = opAccesses[op];
113 slot =
new (opAlloc.Allocate()) OpAccesses(op);
118 BlockAccesses &
get(Block *block) {
119 auto &slot = blockAccesses[block];
121 slot =
new (blockAlloc.Allocate()) BlockAccesses(block);
126 void addOpAccess(OpAccesses &op, Access access);
127 void addBlockAccess(BlockAccesses &block, Access access);
132 LogicalResult AccessAnalysis::analyze(Operation *op) {
133 LLVM_DEBUG(llvm::dbgs() <<
"Analyzing accesses in " << op->getName() <<
"\n");
136 llvm::SmallSetVector<OpAccesses *, 16> initWorklist;
137 initWorklist.insert(&
get(op));
138 while (!initWorklist.empty()) {
139 OpAccesses &opNode = *initWorklist.pop_back_val();
142 for (
auto &region : opNode.op->getRegions()) {
143 for (
auto &block : region) {
144 BlockAccesses &blockNode =
get(&block);
145 blockNode.parent = &opNode;
146 for (
auto &subOp : block) {
149 OpAccesses &subOpNode =
get(&subOp);
150 if (!subOp.hasTrait<OpTrait::IsIsolatedFromAbove>()) {
151 subOpNode.parent = &blockNode;
153 initWorklist.insert(&subOpNode);
159 if (
auto callOp = dyn_cast<CallOpInterface>(opNode.op))
160 if (
auto *calleeOp = callOp.resolveCallable(&symbolTable))
161 get(calleeOp).callers.insert(&opNode);
164 if (
auto readOp = dyn_cast<StateReadOp>(opNode.op))
165 addOpAccess(opNode, Access(readOp.getState(), AccessType::Read));
166 else if (
auto writeOp = dyn_cast<StateWriteOp>(opNode.op))
167 addOpAccess(opNode, Access(writeOp.getState(), AccessType::Write));
169 LLVM_DEBUG(llvm::dbgs() <<
"- Prepared " << blockAccesses.size()
170 <<
" block and " << opAccesses.size()
171 <<
" op lattice nodes\n");
172 LLVM_DEBUG(llvm::dbgs() <<
"- Worklist has " << opWorklist.size()
173 <<
" initial ops\n");
176 while (!opWorklist.empty()) {
177 if (anyInvalidStateAccesses)
179 auto &opNode = *opWorklist.pop_back_val();
180 if (opNode.callers.empty())
182 auto calleeOp = dyn_cast<CallableOpInterface>(opNode.op);
184 return opNode.op->emitOpError(
185 "does not implement CallableOpInterface but has callers");
186 LLVM_DEBUG(llvm::dbgs() <<
"- Updating callable " << opNode.op->getName()
187 <<
" " << opNode.op->getAttr(
"sym_name") <<
"\n");
189 auto &calleeRegion = *calleeOp.getCallableRegion();
190 auto *blockNode = lookup(&calleeRegion.front());
193 auto calleeArgs = blockNode->block->getArguments();
195 for (
auto *callOpNode : opNode.callers) {
196 LLVM_DEBUG(llvm::dbgs() <<
" - Updating " << *callOpNode->op <<
"\n");
197 auto callArgs = cast<CallOpInterface>(callOpNode->op).getArgOperands();
198 for (
auto [calleeArg, callArg] : llvm::zip(calleeArgs, callArgs)) {
199 if (blockNode->argAccesses.contains({calleeArg, AccessType::Read}))
200 addOpAccess(*callOpNode, {callArg, AccessType::Read});
201 if (blockNode->argAccesses.contains({calleeArg, AccessType::Write}))
202 addOpAccess(*callOpNode, {callArg, AccessType::Write});
207 return failure(anyInvalidStateAccesses);
210 OpAccesses *AccessAnalysis::lookup(Operation *op) {
211 return opAccesses.lookup(op);
214 BlockAccesses *AccessAnalysis::lookup(Block *block) {
215 return blockAccesses.lookup(block);
219 void AccessAnalysis::addOpAccess(OpAccesses &op, Access access) {
223 auto *defOp = access.getPointer().getDefiningOp();
224 if (defOp && !isa<AllocStateOp, RootInputOp, RootOutputOp>(defOp)) {
225 auto d = op.op->emitOpError(
"accesses non-trivial state value defined by `")
227 <<
"`; only block arguments and `arc.alloc_state` results are "
229 d.attachNote(defOp->getLoc()) <<
"state defined here";
230 anyInvalidStateAccesses =
true;
236 if (isa<PassThroughOp>(op.op))
241 if (op.accesses.insert(access).second && op.parent) {
242 stateOrder.insert({access.getPointer(), stateOrder.size()});
243 addBlockAccess(*op.parent, access);
247 void AccessAnalysis::addBlockAccess(BlockAccesses &block, Access access) {
248 Value value = access.getPointer();
252 if (value.getParentBlock() != block.block) {
253 if (block.aboveAccesses.insert(access).second)
254 addOpAccess(*block.parent, access);
260 if (
auto blockArg = dyn_cast<BlockArgument>(value)) {
261 assert(blockArg.getOwner() == block.block);
262 if (!block.argAccesses.insert(access).second)
267 opWorklist.insert(block.parent);
278 Legalizer(AccessAnalysis &analysis) : analysis(analysis) {}
279 LogicalResult run(MutableArrayRef<Region> regions);
280 LogicalResult visitBlock(Block *block);
282 AccessAnalysis &analysis;
284 unsigned numLegalizedWrites = 0;
285 unsigned numUpdatedReads = 0;
290 DenseMap<Value, Value> legalizedStates;
294 LogicalResult Legalizer::run(MutableArrayRef<Region> regions) {
295 for (
auto &region : regions)
296 for (
auto &block : region)
297 if (failed(visitBlock(&block)))
299 assert(legalizedStates.empty() &&
"should be balanced within block");
303 LogicalResult Legalizer::visitBlock(Block *block) {
306 SmallPtrSet<Value, 4> readStates;
307 DenseMap<Value, Operation *> illegallyWrittenStates;
308 for (Operation &op : llvm::reverse(*block)) {
309 const auto *accesses = analysis.lookup(&op);
315 SmallVector<Value, 1> affectedStates;
316 for (
auto access : accesses->accesses)
317 if (access.getInt() == AccessType::Write)
318 if (readStates.contains(access.getPointer()))
319 illegallyWrittenStates[access.getPointer()] = &op;
325 for (
auto access : accesses->accesses)
326 if (access.getInt() == AccessType::Read)
327 readStates.insert(access.getPointer());
334 DenseMap<Operation *, SmallVector<Value, 1>> illegalWrites;
335 for (
auto [state, op] : illegallyWrittenStates)
336 if (!legalizedStates.count(state))
337 illegalWrites[op].push_back(state);
342 SmallVector<Value> locallyLegalizedStates;
344 auto handleIllegalWrites =
345 [&](Operation *op, SmallVector<Value, 1> &states) -> LogicalResult {
346 LLVM_DEBUG(llvm::dbgs() <<
"Visiting illegal " << op->getName() <<
"\n");
351 llvm::sort(states, [&](Value a, Value b) {
352 return analysis.stateOrder.lookup(a) < analysis.stateOrder.lookup(b);
356 for (
auto state : states) {
357 LLVM_DEBUG(llvm::dbgs() <<
"- Legalizing " << state <<
"\n");
362 auto storage = TypeSwitch<Operation *, Value>(state.getDefiningOp())
363 .Case<AllocStateOp, RootInputOp, RootOutputOp>(
364 [&](
auto allocOp) {
return allocOp.getStorage(); })
365 .Default([](
auto) {
return Value{}; });
369 "cannot find storage pointer to allocate temporary into");
375 ++numLegalizedWrites;
376 ImplicitLocOpBuilder
builder(state.getLoc(), op);
378 builder.create<AllocStateOp>(state.getType(), storage,
nullptr);
379 auto stateValue =
builder.create<StateReadOp>(state);
380 builder.create<StateWriteOp>(tmpState, stateValue, Value{});
381 locallyLegalizedStates.push_back(state);
382 legalizedStates.insert({state, tmpState});
387 for (Operation &op : *block) {
389 if (
auto it = illegalWrites.find(&op); it != illegalWrites.end())
390 if (failed(handleIllegalWrites(&op, it->second)))
401 const auto *accesses = analysis.lookup(&op);
402 for (
auto &operand : op.getOpOperands()) {
404 accesses->accesses.contains({operand.get(), AccessType::Read}) &&
405 accesses->accesses.contains({operand.get(), AccessType::Write})) {
406 auto d = op.emitWarning(
"operation reads and writes state; "
407 "legalization may be insufficient");
409 <<
"state update legalization does not properly handle operations "
410 "that both read and write states at the same time; runtime data "
411 "races between the read and write behavior are possible";
412 d.attachNote(operand.get().getLoc()) <<
"state defined here:";
415 !accesses->accesses.contains({operand.get(), AccessType::Write})) {
416 if (
auto tmpState = legalizedStates.lookup(operand.get())) {
417 operand.set(tmpState);
422 for (
auto &region : op.getRegions())
423 for (
auto &block : region)
424 if (failed(visitBlock(&block)))
430 for (
auto state : locallyLegalizedStates)
431 legalizedStates.erase(state);
436 Operation *write, Operation **writeAncestor, Operation *read,
437 Operation **readAncestor, DominanceInfo *domInfo) {
438 Block *commonDominator =
439 domInfo->findNearestCommonDominator(write->getBlock(), read->getBlock());
440 if (!commonDominator)
441 return write->emitOpError(
442 "cannot find a common dominator block with all read operations");
446 Operation *writeParent = write;
447 while (writeParent->getBlock() != commonDominator) {
448 if (!isa<scf::IfOp, ClockTreeOp>(writeParent->getParentOp()))
449 return write->emitOpError(
"memory write operations in arbitrarily nested "
450 "regions not supported");
451 writeParent = writeParent->getParentOp();
453 Operation *readParent = read;
454 while (readParent->getBlock() != commonDominator)
455 readParent = readParent->getParentOp();
457 *writeAncestor = writeParent;
458 *readAncestor = readParent;
464 DominanceInfo *domInfo) {
466 DenseMap<Value, SetVector<Operation *>> readOps;
467 auto result = region.walk([&](Operation *op) {
468 if (isa<MemoryWriteOp>(op))
469 return WalkResult::advance();
470 SmallVector<Value> memoriesReadFrom;
471 if (
auto readOp = dyn_cast<MemoryReadOp>(op)) {
472 memoriesReadFrom.push_back(readOp.getMemory());
474 for (
auto operand : op->getOperands())
475 if (isa<MemoryType>(operand.getType()))
476 memoriesReadFrom.push_back(operand);
478 for (
auto memVal : memoriesReadFrom) {
479 if (!memories.contains(memVal))
480 return op->emitOpError(
"uses memory value not directly defined by a "
481 "arc.alloc_memory operation"),
482 WalkResult::interrupt();
483 readOps[memVal].insert(op);
486 return WalkResult::advance();
489 if (result.wasInterrupted())
493 SmallVector<MemoryWriteOp> writes;
494 region.walk([&](MemoryWriteOp writeOp) { writes.push_back(writeOp); });
497 for (
auto writeOp : writes) {
498 if (!memories.contains(writeOp.getMemory()))
499 return writeOp->emitOpError(
"uses memory value not directly defined by a "
500 "arc.alloc_memory operation");
501 for (
auto *readOp : readOps[writeOp.getMemory()]) {
517 Operation *readAncestor, *writeAncestor;
519 writeOp, &writeAncestor, readOp, &readAncestor, domInfo)))
523 if (writeAncestor->isBeforeInBlock(readAncestor))
524 writeAncestor->moveAfter(readAncestor);
529 for (
auto writeOp : writes) {
530 for (
auto *readOp : readOps[writeOp.getMemory()]) {
531 Operation *readAncestor, *writeAncestor;
533 writeOp, &writeAncestor, readOp, &readAncestor, domInfo)))
536 if (writeAncestor->isBeforeInBlock(readAncestor))
538 ->emitOpError(
"could not be moved to be after all reads to "
540 .attachNote(readOp->getLoc())
541 <<
"could not be moved after this read";
553 struct LegalizeStateUpdatePass
554 :
public arc::impl::LegalizeStateUpdateBase<LegalizeStateUpdatePass> {
555 LegalizeStateUpdatePass() =
default;
556 LegalizeStateUpdatePass(
const LegalizeStateUpdatePass &pass)
557 : LegalizeStateUpdatePass() {}
559 void runOnOperation()
override;
561 Statistic numLegalizedWrites{
562 this,
"legalized-writes",
563 "Writes that required temporary state for later reads"};
564 Statistic numUpdatedReads{
this,
"updated-reads",
"Reads that were updated"};
568 void LegalizeStateUpdatePass::runOnOperation() {
569 auto module = getOperation();
570 auto *domInfo = &getAnalysis<DominanceInfo>();
572 for (
auto model : module.getOps<ModelOp>()) {
573 DenseSet<Value> memories;
574 for (
auto memOp : model.getOps<AllocMemoryOp>())
575 memories.insert(memOp.getResult());
576 for (
auto ct : model.getOps<ClockTreeOp>())
579 return signalPassFailure();
582 AccessAnalysis analysis;
583 if (failed(analysis.analyze(module)))
584 return signalPassFailure();
586 Legalizer legalizer(analysis);
587 if (failed(legalizer.run(module->getRegions())))
588 return signalPassFailure();
589 numLegalizedWrites += legalizer.numLegalizedWrites;
590 numUpdatedReads += legalizer.numUpdatedReads;
594 return std::make_unique<LegalizeStateUpdatePass>();
assert(baseType &&"element must be base type")
static LogicalResult getAncestorOpsInCommonDominatorBlock(Operation *write, Operation **writeAncestor, Operation *read, Operation **readAncestor, DominanceInfo *domInfo)
static LogicalResult moveMemoryWritesAfterLastRead(Region &region, const DenseSet< Value > &memories, DominanceInfo *domInfo)
static bool isOpInteresting(Operation *op)
Check if an operation partakes in state accesses.
std::unique_ptr< mlir::Pass > createLegalizeStateUpdatePass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.