11 #include "mlir/Dialect/SCF/IR/SCF.h"
12 #include "mlir/IR/Dominance.h"
13 #include "mlir/IR/ImplicitLocOpBuilder.h"
14 #include "llvm/ADT/PointerIntPair.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "arc-legalize-state-update"
22 #define GEN_PASS_DEF_LEGALIZESTATEUPDATE
23 #include "circt/Dialect/Arc/ArcPasses.h.inc"
28 using namespace circt;
// Fragment of `isOpInteresting`, which decides whether an operation partakes
// in state accesses. State reads/writes and call/callable ops are tested
// first, then any op that carries regions.
// NOTE(review): the statements guarded by these two `if`s are not part of
// this listing; presumably each returns true -- confirm against the full file.
33 if (isa<StateReadOp, StateWriteOp, CallOpInterface, CallableOpInterface>(op))
35 if (op->getNumRegions() > 0)
// Kind of access an operation performs on a state value.
46 enum class AccessType { Read = 0, Write = 1 };
// A state value paired with the kind of access performed on it. The
// AccessType occupies a single bit packed alongside the `Value` by
// `llvm::PointerIntPair`.
49 using Access = llvm::PointerIntPair<Value, 1, AccessType>;
// Per-block lattice node of the state access analysis.
55 struct BlockAccesses {
56 BlockAccesses(Block *block) : block(block) {}
// Node of the operation owning the region that contains this block; assigned
// while `AccessAnalysis::analyze` walks the IR.
61 OpAccesses *parent =
nullptr;
// Accesses to arguments of this very block. `addBlockAccess` queues the
// parent op when an entry is added, so the accesses can be forwarded to
// the parent's callers.
63 SmallPtrSet<Access, 1> argAccesses;
// Accesses to values defined outside this block; each new entry is forwarded
// to the parent op.
66 SmallPtrSet<Access, 1> aboveAccesses;
// Fragment of the per-operation lattice node (its struct header is not part
// of this listing).
71 OpAccesses(Operation *op) : op(op) {}
// Node of the block containing this op. Left null for ops that are
// `IsolatedFromAbove`, in which case their accesses are not propagated
// outward.
76 BlockAccesses *parent =
nullptr;
// Call operations that resolve to this op as their callee.
78 SmallPtrSet<OpAccesses *, 1> callers;
// State accesses performed by this op, including accesses propagated up from
// nested blocks and from callees.
80 SmallPtrSet<Access, 1> accesses;
// Computes which state values each op and block reads and writes. Accesses
// are propagated outward through enclosing ops, and from callees to calls.
86 struct AccessAnalysis {
87 LogicalResult analyze(Operation *op);
88 OpAccesses *lookup(Operation *op);
89 BlockAccesses *lookup(Block *block);
// Order in which state values were first recorded as accessed. The Legalizer
// uses it to sort the states it legalizes.
94 DenseMap<Value, unsigned> stateOrder;
// Used to resolve the callee of call operations.
97 SymbolTableCollection symbolTable;
// Arena storage for the lattice nodes.
100 llvm::SpecificBumpPtrAllocator<OpAccesses> opAlloc;
101 llvm::SpecificBumpPtrAllocator<BlockAccesses> blockAlloc;
103 DenseMap<Operation *, OpAccesses *> opAccesses;
104 DenseMap<Block *, BlockAccesses *> blockAccesses;
// Ops whose blocks gained new argument accesses and whose callers therefore
// still need to be updated.
106 SetVector<OpAccesses *> opWorklist;
// Set once any access to an unsupported state value has been diagnosed.
107 bool anyInvalidStateAccesses =
false;
// Return the node for `op`, allocating it from `opAlloc` on demand.
// NOTE(review): the guard between the map lookup and the allocation is not
// part of this listing; presumably a node is only created for an empty slot.
110 OpAccesses &
get(Operation *op) {
111 auto &slot = opAccesses[op];
113 slot =
new (opAlloc.Allocate()) OpAccesses(op);
// Block counterpart of the op overload above, allocating from `blockAlloc`.
118 BlockAccesses &
get(Block *block) {
119 auto &slot = blockAccesses[block];
121 slot =
new (blockAlloc.Allocate()) BlockAccesses(block);
126 void addOpAccess(OpAccesses &op, Access access);
127 void addBlockAccess(BlockAccesses &block, Access access);
// Run the access analysis rooted at `op`. First walk all nested regions to
// create lattice nodes, link them to their parents, and seed accesses from
// StateReadOp/StateWriteOp. Then forward block-argument accesses of callables
// to their call sites until the worklist is empty. Returns failure if any
// invalid state access was diagnosed.
132 LogicalResult AccessAnalysis::analyze(Operation *op) {
133 LLVM_DEBUG(
llvm::dbgs() <<
"Analyzing accesses in " << op->getName() <<
"\n");
// Phase 1: create a node for every op and block reachable from `op`.
136 llvm::SmallSetVector<OpAccesses *, 16> initWorklist;
137 initWorklist.insert(&
get(op));
138 while (!initWorklist.empty()) {
139 OpAccesses &opNode = *initWorklist.pop_back_val();
142 for (
auto ®ion : opNode.op->getRegions()) {
143 for (
auto &block : region) {
144 BlockAccesses &blockNode =
get(&block);
145 blockNode.parent = &opNode;
146 for (
auto &subOp : block) {
149 OpAccesses &subOpNode =
get(&subOp);
// Isolated-from-above ops are not linked to their enclosing block, so the
// accesses recorded on them are not propagated outward.
150 if (!subOp.hasTrait<OpTrait::IsIsolatedFromAbove>()) {
151 subOpNode.parent = &blockNode;
// NOTE(review): the brace closing the `if` above, and any filtering of
// `subOp`, are not part of this listing; confirm which sub-ops get queued.
153 initWorklist.insert(&subOpNode);
// Register call ops as callers of the callable they resolve to.
159 if (
auto callOp = dyn_cast<CallOpInterface>(opNode.op))
160 if (
auto *calleeOp = callOp.resolveCallable(&symbolTable))
161 get(calleeOp).callers.insert(&opNode);
// Seed the analysis with the op's own state read or write.
164 if (
auto readOp = dyn_cast<StateReadOp>(opNode.op))
165 addOpAccess(opNode, Access(readOp.getState(), AccessType::Read));
166 else if (
auto writeOp = dyn_cast<StateWriteOp>(opNode.op))
167 addOpAccess(opNode, Access(writeOp.getState(), AccessType::Write));
169 LLVM_DEBUG(
llvm::dbgs() <<
"- Prepared " << blockAccesses.size()
170 <<
" block and " << opAccesses.size()
171 <<
" op lattice nodes\n");
172 LLVM_DEBUG(
llvm::dbgs() <<
"- Worklist has " << opWorklist.size()
173 <<
" initial ops\n");
// Phase 2: forward the block-argument accesses of each queued callable to the
// matching operands of every call site. New accesses on a call op propagate
// through `addOpAccess`/`addBlockAccess` and may queue further ops.
// NOTE(review): the statement guarded by `if (anyInvalidStateAccesses)` is not
// part of this listing; presumably it stops the loop early.
176 while (!opWorklist.empty()) {
177 if (anyInvalidStateAccesses)
179 auto &opNode = *opWorklist.pop_back_val();
180 if (opNode.callers.empty())
182 auto calleeOp = dyn_cast<CallableOpInterface>(opNode.op);
184 return opNode.op->emitOpError(
185 "does not implement CallableOpInterface but has callers");
186 LLVM_DEBUG(
llvm::dbgs() <<
"- Updating callable " << opNode.op->getName()
187 <<
" " << opNode.op->getAttr(
"sym_name") <<
"\n");
// Argument accesses are taken from the entry block of the callable region.
189 auto &calleeRegion = *calleeOp.getCallableRegion();
190 auto *blockNode = lookup(&calleeRegion.front());
193 auto calleeArgs = blockNode->block->getArguments();
195 for (
auto *callOpNode : opNode.callers) {
196 LLVM_DEBUG(
llvm::dbgs() <<
" - Updating " << *callOpNode->op <<
"\n");
197 auto callArgs = cast<CallOpInterface>(callOpNode->op).getArgOperands();
// Callee block arguments and call operands are paired positionally.
198 for (
auto [calleeArg, callArg] : llvm::zip(calleeArgs, callArgs)) {
199 if (blockNode->argAccesses.contains({calleeArg, AccessType::Read}))
200 addOpAccess(*callOpNode, {callArg, AccessType::Read});
201 if (blockNode->argAccesses.contains({calleeArg, AccessType::Write}))
202 addOpAccess(*callOpNode, {callArg, AccessType::Write});
// Fail if any invalid state access was diagnosed along the way.
207 return failure(anyInvalidStateAccesses);
// Return the lattice node for `op`, or null if the analysis created none.
210 OpAccesses *AccessAnalysis::lookup(Operation *op) {
211 return opAccesses.lookup(op);
// Return the lattice node for `block`, or null if the analysis created none.
214 BlockAccesses *AccessAnalysis::lookup(Block *block) {
215 return blockAccesses.lookup(block);
// Record that `op` performs `access`. If the access is new and `op` is linked
// to an enclosing block node, propagate it to that block.
219 void AccessAnalysis::addOpAccess(OpAccesses &op, Access access) {
// Only block arguments and results of AllocStateOp, RootInputOp or
// RootOutputOp are accepted as state. Any other defining op is diagnosed, and
// the analysis is flagged as failed.
223 auto *defOp = access.getPointer().getDefiningOp();
224 if (defOp && !isa<AllocStateOp, RootInputOp, RootOutputOp>(defOp)) {
225 auto d = op.op->emitOpError(
"accesses non-trivial state value defined by `")
227 <<
"`; only block arguments and `arc.alloc_state` results are "
229 d.attachNote(defOp->getLoc()) <<
"state defined here";
230 anyInvalidStateAccesses =
true;
// NOTE(review): the special handling of PassThroughOp is not part of this
// listing.
236 if (isa<PassThroughOp>(op.op))
241 if (op.accesses.insert(access).second && op.parent) {
// `insert` keeps an existing entry, so this records first-access order.
242 stateOrder.insert({access.getPointer(), stateOrder.size()});
243 addBlockAccess(*op.parent, access);
// Record that some op inside `block` performs `access`.
247 void AccessAnalysis::addBlockAccess(BlockAccesses &block, Access access) {
248 Value
value = access.getPointer();
// A value defined outside this block is recorded once and forwarded to the
// block's parent op.
252 if (
value.getParentBlock() != block.block) {
253 if (block.aboveAccesses.insert(access).second)
254 addOpAccess(*block.parent, access);
// An argument of this block is recorded in `argAccesses`, and the parent op
// is queued so `analyze` can forward the access to its callers.
// NOTE(review): the statement guarded by the `insert` check is not part of
// this listing; presumably it returns when the access was already known.
260 if (
auto blockArg = dyn_cast<BlockArgument>(
value)) {
261 assert(blockArg.getOwner() == block.block);
262 if (!block.argAccesses.insert(access).second)
267 opWorklist.insert(block.parent);
// Fragment of the `Legalizer` (its struct header is not part of this listing).
// It uses the access analysis to insert temporary copies of states that are
// written before a later read in the same block.
278 Legalizer(AccessAnalysis &analysis) : analysis(analysis) {}
279 LogicalResult run(MutableArrayRef<Region> regions);
280 LogicalResult visitBlock(Block *block);
282 AccessAnalysis &analysis;
// Counters that the pass adds to its statistics after running.
284 unsigned numLegalizedWrites = 0;
285 unsigned numUpdatedReads = 0;
// Maps a state value to the temporary state that holds its pre-write value.
// `visitBlock` erases the entries it adds before returning.
290 DenseMap<Value, Value> legalizedStates;
// Legalize every block of every region in `regions`.
294 LogicalResult Legalizer::run(MutableArrayRef<Region> regions) {
295 for (
auto ®ion : regions)
296 for (
auto &block : region)
297 if (failed(visitBlock(&block)))
// Each visited block removes the temporaries it introduced, so none may remain.
299 assert(legalizedStates.empty() &&
"should be balanced within block");
// Legalize the state accesses in `block` and, recursively, in all blocks nested
// within it. A write to a state that a later op in the block reads is handled
// as follows: right before the writing op, the state's current value is copied
// into a temporary state. Later operands that only read the state are then
// redirected to that temporary.
303 LogicalResult Legalizer::visitBlock(Block *block) {
// Backward scan: track the states read further down the block, so that a
// write to one of them is known to precede a read.
306 SmallPtrSet<Value, 4> readStates;
307 DenseMap<Value, Operation *> illegallyWrittenStates;
308 for (Operation &op : llvm::reverse(*block)) {
309 const auto *accesses = analysis.lookup(&op);
315 SmallVector<Value, 1> affectedStates;
// Writes are checked before this op's own reads are added, so only reads by
// later ops count. Because the scan runs in reverse, each state's entry ends
// up as the earliest such writer in the block.
316 for (
auto access : accesses->accesses)
317 if (access.getInt() == AccessType::Write)
318 if (readStates.contains(access.getPointer()))
319 illegallyWrittenStates[access.getPointer()] = &op;
325 for (
auto access : accesses->accesses)
326 if (access.getInt() == AccessType::Read)
327 readStates.insert(access.getPointer());
// Group the writes that still need legalizing by the op performing them. At
// this point `legalizedStates` only holds temporaries from enclosing blocks,
// and states that already have one are skipped.
334 DenseMap<Operation *, SmallVector<Value, 1>> illegalWrites;
335 for (
auto [state, op] : illegallyWrittenStates)
336 if (!legalizedStates.count(state))
337 illegalWrites[op].push_back(state);
// States legalized in this block; their temporaries are dropped from
// `legalizedStates` once this block is done.
342 SmallVector<Value> locallyLegalizedStates;
// Insert, right before `op`, a temporary copy of each state in `states`.
344 auto handleIllegalWrites =
345 [&](Operation *op, SmallVector<Value, 1> &states) -> LogicalResult {
346 LLVM_DEBUG(
llvm::dbgs() <<
"Visiting illegal " << op->getName() <<
"\n");
// Process the states in the order they were first seen by the analysis.
351 llvm::sort(states, [&](Value a, Value b) {
352 return analysis.stateOrder.lookup(a) < analysis.stateOrder.lookup(b);
356 for (
auto state : states) {
357 LLVM_DEBUG(
llvm::dbgs() <<
"- Legalizing " << state <<
"\n");
// The temporary is allocated into the same storage as the original state,
// found through the state's defining op.
362 auto storage = TypeSwitch<Operation *, Value>(state.getDefiningOp())
363 .Case<AllocStateOp, RootInputOp, RootOutputOp>(
364 [&](
auto allocOp) {
return allocOp.getStorage(); })
365 .Default([](
auto) {
return Value{}; });
369 "cannot find storage pointer to allocate temporary into");
375 ++numLegalizedWrites;
// The builder inserts before `op`: allocate the temporary, read the state's
// current value, and store it into the temporary.
376 ImplicitLocOpBuilder
builder(state.getLoc(), op);
378 builder.create<AllocStateOp>(state.getType(), storage,
nullptr);
379 auto stateValue =
builder.create<StateReadOp>(state);
380 builder.create<StateWriteOp>(tmpState, stateValue, Value{});
381 locallyLegalizedStates.push_back(state);
382 legalizedStates.insert({state, tmpState});
// Forward pass: insert temporaries ahead of the illegal writes, then redirect
// read-only operand uses to the temporaries.
387 for (Operation &op : *block) {
389 if (
auto it = illegalWrites.find(&op); it != illegalWrites.end())
390 if (failed(handleIllegalWrites(&op, it->second)))
401 const auto *accesses = analysis.lookup(&op);
402 for (
auto &operand : op.getOpOperands()) {
404 accesses->accesses.contains({operand.get(), AccessType::Read}) &&
405 accesses->accesses.contains({operand.get(), AccessType::Write})) {
406 auto d = op.emitWarning(
"operation reads and writes state; "
407 "legalization may be insufficient");
409 <<
"state update legalization does not properly handle operations "
410 "that both read and write states at the same time; runtime data "
411 "races between the read and write behavior are possible";
412 d.attachNote(operand.get().getLoc()) <<
"state defined here:";
415 !accesses->accesses.contains({operand.get(), AccessType::Write})) {
416 if (
auto tmpState = legalizedStates.lookup(operand.get())) {
417 operand.set(tmpState);
// Recurse into nested regions so that their reads also see the temporaries.
422 for (
auto ®ion : op.getRegions())
423 for (
auto &block : region)
424 if (failed(visitBlock(&block)))
// Temporaries introduced here only apply within this block.
430 for (
auto state : locallyLegalizedStates)
431 legalizedStates.erase(state);
436 Operation *write, Operation **writeAncestor, Operation *read,
437 Operation **readAncestor, DominanceInfo *domInfo) {
// Find the nearest block dominating both the write and the read. Report, via
// the out-parameters, the ancestor of each op that lives directly in it.
438 Block *commonDominator =
439 domInfo->findNearestCommonDominator(write->getBlock(), read->getBlock());
440 if (!commonDominator)
441 return write->emitOpError(
442 "cannot find a common dominator block with all read operations");
// Climb from the write up to the dominator block. Only scf::IfOp and
// ClockTreeOp parents may be crossed on the way.
446 Operation *writeParent = write;
447 while (writeParent->getBlock() != commonDominator) {
448 if (!isa<scf::IfOp, ClockTreeOp>(writeParent->getParentOp()))
449 return write->emitOpError(
"memory write operations in arbitrarily nested "
450 "regions not supported");
451 writeParent = writeParent->getParentOp();
// Climb from the read up to the dominator block; its parents are not
// restricted.
453 Operation *readParent = read;
454 while (readParent->getBlock() != commonDominator)
455 readParent = readParent->getParentOp();
457 *writeAncestor = writeParent;
458 *readAncestor = readParent;
464 DominanceInfo *domInfo) {
// Collect, per memory, every op in `region` that reads from it. This covers
// MemoryReadOp and, presumably for other ops, any memory-typed operand.
// MemoryWriteOps are never counted as readers. Memories not in `memories`
// are diagnosed and abort the walk.
466 DenseMap<Value, SetVector<Operation *>> readOps;
467 auto result = region.walk([&](Operation *op) {
468 if (isa<MemoryWriteOp>(op))
469 return WalkResult::advance();
470 SmallVector<Value> memoriesReadFrom;
471 if (
auto readOp = dyn_cast<MemoryReadOp>(op)) {
472 memoriesReadFrom.push_back(readOp.getMemory());
474 for (
auto operand : op->getOperands())
475 if (isa<MemoryType>(operand.getType()))
476 memoriesReadFrom.push_back(operand);
478 for (
auto memVal : memoriesReadFrom) {
479 if (!memories.contains(memVal))
480 return op->emitOpError(
"uses memory value not directly defined by a "
481 "arc.alloc_memory operation"),
482 WalkResult::interrupt();
483 readOps[memVal].insert(op);
486 return WalkResult::advance();
489 if (result.wasInterrupted())
// For each write, move its ancestor in the common dominator block after the
// ancestor of every read of the same memory.
493 SmallVector<MemoryWriteOp> writes;
494 region.walk([&](MemoryWriteOp writeOp) { writes.push_back(writeOp); });
497 for (
auto writeOp : writes) {
498 if (!memories.contains(writeOp.getMemory()))
499 return writeOp->emitOpError(
"uses memory value not directly defined by a "
500 "arc.alloc_memory operation");
501 for (
auto *readOp : readOps[writeOp.getMemory()]) {
517 Operation *readAncestor, *writeAncestor;
519 writeOp, &writeAncestor, readOp, &readAncestor, domInfo)))
523 if (writeAncestor->isBeforeInBlock(readAncestor))
524 writeAncestor->moveAfter(readAncestor);
// Verification pass: diagnose any write ancestor that still precedes a read
// ancestor of the same memory.
529 for (
auto writeOp : writes) {
530 for (
auto *readOp : readOps[writeOp.getMemory()]) {
531 Operation *readAncestor, *writeAncestor;
533 writeOp, &writeAncestor, readOp, &readAncestor, domInfo)))
536 if (writeAncestor->isBeforeInBlock(readAncestor))
538 ->emitOpError(
"could not be moved to be after all reads to "
540 .attachNote(readOp->getLoc())
541 <<
"could not be moved after this read";
// Pass that reorders memory writes after their reads within clock trees and
// legalizes state updates across the module.
553 struct LegalizeStateUpdatePass
554 :
public arc::impl::LegalizeStateUpdateBase<LegalizeStateUpdatePass> {
555 LegalizeStateUpdatePass() =
default;
// Copying delegates to the default constructor. The copy therefore starts
// with freshly initialized statistics instead of copying them.
556 LegalizeStateUpdatePass(
const LegalizeStateUpdatePass &pass)
557 : LegalizeStateUpdatePass() {}
559 void runOnOperation()
override;
// Statistics accumulated from the Legalizer's counters in runOnOperation.
561 Statistic numLegalizedWrites{
562 this,
"legalized-writes",
563 "Writes that required temporary state for later reads"};
564 Statistic numUpdatedReads{
this,
"updated-reads",
"Reads that were updated"};
// For each ModelOp, first process its clock trees against the model's
// memories. Then run the access analysis and state-update legalization once
// over the whole module, and accumulate the legalizer's counters into the
// pass statistics.
// NOTE(review): the per-ClockTreeOp call inside the loop is not part of this
// listing; presumably it is moveMemoryWritesAfterLastRead on the tree's body.
568 void LegalizeStateUpdatePass::runOnOperation() {
569 auto module = getOperation();
570 auto *domInfo = &getAnalysis<DominanceInfo>();
572 for (
auto model : module.getOps<ModelOp>()) {
// Memories that memory ops may refer to: results of AllocMemoryOps directly
// inside the model.
573 DenseSet<Value> memories;
574 for (
auto memOp : model.getOps<AllocMemoryOp>())
575 memories.insert(memOp.getResult());
576 for (
auto ct : model.getOps<ClockTreeOp>())
579 return signalPassFailure();
// State-update legalization runs once over the whole module.
582 AccessAnalysis analysis;
583 if (failed(analysis.analyze(module)))
584 return signalPassFailure();
586 Legalizer legalizer(analysis);
587 if (failed(legalizer.run(module->getRegions())))
588 return signalPassFailure();
589 numLegalizedWrites += legalizer.numLegalizedWrites;
590 numUpdatedReads += legalizer.numUpdatedReads;
// Body of createLegalizeStateUpdatePass: return a new instance of the pass.
594 return std::make_unique<LegalizeStateUpdatePass>();
assert(baseType &&"element must be base type")
static LogicalResult getAncestorOpsInCommonDominatorBlock(Operation *write, Operation **writeAncestor, Operation *read, Operation **readAncestor, DominanceInfo *domInfo)
static LogicalResult moveMemoryWritesAfterLastRead(Region ®ion, const DenseSet< Value > &memories, DominanceInfo *domInfo)
static bool isOpInteresting(Operation *op)
Check if an operation partakes in state accesses.
std::unique_ptr< mlir::Pass > createLegalizeStateUpdatePass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & dbgs()