60#include "mlir/Analysis/Liveness.h"
61#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
62#include "mlir/Dialect/Func/IR/FuncOps.h"
63#include "mlir/Dialect/UB/IR/UBOps.h"
64#include "mlir/IR/AttrTypeSubElements.h"
65#include "mlir/IR/Dominance.h"
66#include "mlir/Pass/Pass.h"
67#include "llvm/ADT/TypeSwitch.h"
68#include "llvm/Support/GenericIteratedDominanceFrontier.h"
72#define GEN_PASS_DEF_LOWERCOROUTINESPASS
73#include "circt/Dialect/Arc/ArcPasses.h.inc"
105 SmallPtrSet<Block *, 8> resumeBlockSet;
106 for (
auto &block : region)
107 if (
auto yieldOp = dyn_cast<CoroutineYieldOp>(block.getTerminator()))
108 resumeBlockSet.insert(yieldOp.getDest());
109 SmallVector<Block *> resumeBlocks;
110 for (
auto &block : region)
111 if (resumeBlockSet.contains(&block))
112 resumeBlocks.push_back(&block);
119 ArrayRef<Block *> captureBlocks,
121 DominanceInfo &dominance) {
122 auto *defBlock = value.getParentBlock();
130 auto &domTree = dominance.getDomTree(®ion);
131 llvm::IDFCalculatorBase<Block, false> idfCalculator(domTree);
133 SmallPtrSet<Block *, 8> definingBlocks(captureBlocks.begin(),
134 captureBlocks.end());
135 definingBlocks.insert(defBlock);
136 idfCalculator.setDefiningBlocks(definingBlocks);
138 SmallPtrSet<Block *, 16> liveInBlocks;
139 for (
auto &block : region)
140 if (liveness.getLiveness(&block)->isLiveIn(value))
141 liveInBlocks.insert(&block);
142 idfCalculator.setLiveInBlocks(liveInBlocks);
144 SmallVector<Block *> mergeBlocks;
145 idfCalculator.calculate(mergeBlocks);
147 SmallPtrSet<Block *, 16> argBlocks(mergeBlocks.begin(), mergeBlocks.end());
148 argBlocks.insert(captureBlocks.begin(), captureBlocks.end());
156 struct WorklistItem {
157 DominanceInfoNode *domNode;
160 SmallVector<WorklistItem> worklist;
161 worklist.push_back({domTree.getNode(defBlock), value});
163 while (!worklist.empty()) {
164 auto item = worklist.pop_back_val();
165 auto *block = item.domNode->getBlock();
167 if (argBlocks.contains(block))
168 item.reachingDef = block->addArgument(value.getType(), value.getLoc());
171 if (item.reachingDef != value)
172 block->walk([&](Operation *nestedOp) {
173 nestedOp->replaceUsesOfWith(value, item.reachingDef);
178 auto *terminator = block->getTerminator();
179 auto branchOp = dyn_cast<BranchOpInterface>(terminator);
180 for (
auto &blockOperand : terminator->getBlockOperands()) {
181 if (!argBlocks.contains(blockOperand.get()))
184 return terminator->emitOpError()
185 <<
"does not implement `BranchOpInterface`; cannot pass value "
186 "into successor block";
187 branchOp.getSuccessorOperands(blockOperand.getOperandNumber())
188 .
append(item.reachingDef);
191 for (
auto *child : item.domNode->children())
192 worklist.push_back({child, item.reachingDef});
202 ArrayRef<Block *> captureBlocks,
203 Liveness &liveness) {
206 DenseSet<Block *> resumedBlocks(captureBlocks.begin(), captureBlocks.end());
207 SmallVector<Block *> worklist(captureBlocks.begin(), captureBlocks.end());
208 while (!worklist.empty())
209 for (
auto *successor : worklist.pop_back_val()->getSuccessors())
210 if (liveness.getLiveIn(successor).contains(value) &&
211 resumedBlocks.insert(successor).second)
212 worklist.push_back(successor);
215 auto *defOp = value.getDefiningOp();
216 DenseMap<Block *, Value> clones;
217 for (
auto &use : llvm::make_early_inc_range(value.getUses())) {
218 auto *block = region.findAncestorBlockInRegion(*use.getOwner()->getBlock());
219 if (!resumedBlocks.contains(block))
221 auto &clone = clones[block];
223 OpBuilder builder(block, block->begin());
224 clone = builder.clone(*defOp)->getResult(0);
228 if (defOp->use_empty())
236 Region ®ion = defineOp.getBody();
237 if (region.hasOneBlock())
241 if (resumeBlocks.empty())
249 Liveness liveness(defineOp);
250 DominanceInfo dominance(defineOp);
254 SmallVector<Value> values;
255 for (
auto &block : region) {
256 llvm::append_range(values, block.getArguments());
257 for (
auto &op : block)
258 llvm::append_range(values, op.getResults());
261 for (
auto value : values) {
265 SmallVector<Block *> captureBlocks;
266 for (
auto *block : resumeBlocks)
267 if (liveness.getLiveIn(block).contains(value))
268 captureBlocks.push_back(block);
269 if (captureBlocks.empty())
274 auto *defOp = value.getDefiningOp();
275 if (defOp && defOp->hasTrait<OpTrait::ConstantLike>()) {
279 if (failed(
captureValue(value, region, captureBlocks, liveness, dominance)))
294struct CoroutineLowering {
296 CoroutineDefineOp defineOp;
299 SmallVector<Block *> resumeBlocks;
302 SmallVector<std::optional<unsigned>> variantIndices;
311 uint64_t getReturnPC() {
return getHaltPC() - 1; }
313 uint64_t getHaltPC() {
314 return APInt::getAllOnes(pcType.getWidth()).getZExtValue();
318 std::optional<hw::UnionType::FieldInfo> getVariant(
unsigned resumeIndex) {
319 if (
auto fieldIndex = variantIndices[resumeIndex])
320 return cast<hw::UnionType>(stateType).getElements()[*fieldIndex];
331 auto *
context = defineOp.getContext();
332 CoroutineLowering lowering;
333 lowering.defineOp = defineOp;
340 unsigned pcWidth = llvm::Log2_64_Ceil(lowering.resumeBlocks.size() + 3);
341 lowering.pcType = IntegerType::get(
context, pcWidth);
346 unsigned numCoroutineArgs = defineOp.getArgumentTypes().size();
347 SmallVector<hw::UnionType::FieldInfo> variants;
348 for (
auto [index, block] : llvm::enumerate(lowering.resumeBlocks)) {
349 auto persistedArgs = block->getArguments().drop_front(numCoroutineArgs);
350 if (persistedArgs.empty()) {
351 lowering.variantIndices.push_back(std::nullopt);
354 SmallVector<hw::StructType::FieldInfo> fields;
355 for (
auto [fieldIndex, arg] : llvm::enumerate(persistedArgs))
357 {StringAttr::get(
context,
"f" + Twine(fieldIndex)), arg.getType()});
358 lowering.variantIndices.push_back(variants.size());
359 variants.push_back({StringAttr::get(
context,
"r" + Twine(index + 1)),
360 hw::StructType::get(
context, fields), 0});
362 if (variants.empty())
363 lowering.stateType = hw::StructType::get(
context, {});
365 lowering.stateType = hw::UnionType::get(
context, variants);
381 auto defineOp = lowering.defineOp;
382 auto loc = defineOp.getLoc();
383 auto *
context = defineOp.getContext();
384 unsigned pcWidth = lowering.pcType.getWidth();
388 SmallVector<Type> inputTypes{lowering.stateType, lowering.pcType};
389 llvm::append_range(inputTypes, defineOp.getArgumentTypes());
390 SmallVector<Type> resultTypes{lowering.stateType, lowering.pcType};
391 llvm::append_range(resultTypes, defineOp.getResultTypes());
392 OpBuilder builder(defineOp);
394 func::FuncOp::create(builder, loc, defineOp.getSymName(),
395 builder.getFunctionType(inputTypes, resultTypes));
398 funcOp.getBody().getBlocks().splice(funcOp.getBody().end(),
399 defineOp.getBody().getBlocks());
400 auto *entryBlock = &funcOp.getBody().front();
401 auto *dispatchBlock =
402 builder.createBlock(&funcOp.getBody(), funcOp.getBody().begin());
403 Value stateArg = dispatchBlock->addArgument(lowering.stateType, loc);
404 Value pcArg = dispatchBlock->addArgument(lowering.pcType, loc);
405 SmallVector<Value> callerArgs;
406 for (
auto arg : entryBlock->getArguments())
407 callerArgs.push_back(
408 dispatchBlock->addArgument(arg.getType(), arg.getLoc()));
413 SmallVector<APInt> caseValues;
414 SmallVector<Block *> caseBlocks;
415 for (
auto [index, resumeBlock] : llvm::enumerate(lowering.resumeBlocks)) {
416 auto *trampolineBlock = builder.createBlock(resumeBlock);
417 SmallVector<Value> operands = callerArgs;
418 if (
auto variant = lowering.getVariant(index)) {
420 hw::UnionExtractOp::create(builder, loc, stateArg, variant->name);
421 auto explodeOp = hw::StructExplodeOp::create(builder, loc, variantValue);
422 llvm::append_range(operands, explodeOp.getResults());
424 cf::BranchOp::create(builder, loc, resumeBlock, operands);
425 caseValues.push_back(APInt(pcWidth, index + 1));
426 caseBlocks.push_back(trampolineBlock);
432 builder.setInsertionPointToEnd(dispatchBlock);
433 if (lowering.resumeBlocks.empty()) {
434 cf::BranchOp::create(builder, loc, entryBlock, callerArgs);
436 SmallVector<ValueRange> caseOperands(caseBlocks.size());
437 cf::SwitchOp::create(builder, loc, pcArg, entryBlock, callerArgs,
438 caseValues, caseBlocks, caseOperands);
443 DenseMap<Block *, unsigned> resumePCs;
444 for (
auto [index, block] : llvm::enumerate(lowering.resumeBlocks))
445 resumePCs[block] = index + 1;
447 auto lowerTerminator = [&](Operation *op, Value state, uint64_t pc,
448 ValueRange yieldOperands) {
449 OpBuilder builder(op);
451 state = ub::PoisonOp::create(builder, op->getLoc(), lowering.stateType,
455 SmallVector<Value> operands{state, pcValue};
456 llvm::append_range(operands, yieldOperands);
457 func::ReturnOp::create(builder, op->getLoc(), operands);
461 for (
auto &block : funcOp.getBody()) {
462 TypeSwitch<Operation *>(block.getTerminator())
463 .Case<CoroutineYieldOp>([&](CoroutineYieldOp op) {
467 unsigned pc = resumePCs.lookup(op.getDest());
469 if (
auto variant = lowering.getVariant(pc - 1)) {
470 OpBuilder builder(op);
472 builder, op.getLoc(), variant->type, op.getDestOperands());
473 state = hw::UnionCreateOp::create(builder, op.getLoc(),
474 lowering.stateType, variant->name,
477 lowerTerminator(op, state, pc, op.getYieldOperands());
479 .Case<CoroutineReturnOp>([&](CoroutineReturnOp op) {
480 lowerTerminator(op, Value{}, lowering.getReturnPC(),
481 op.getYieldOperands());
483 .Case<CoroutineHaltOp>([&](CoroutineHaltOp op) {
484 lowerTerminator(op, Value{}, lowering.getHaltPC(),
485 op.getYieldOperands());
497struct LowerCoroutinesPass
498 :
public arc::impl::LowerCoroutinesPassBase<LowerCoroutinesPass> {
499 void runOnOperation()
override;
503void LowerCoroutinesPass::runOnOperation() {
504 auto module = getOperation();
510 SmallVector<CoroutineDefineOp> defineOps;
511 auto walkResult =
module->walk([&](Operation *op) {
512 if (auto instanceOp = dyn_cast<CoroutineInstanceOp>(op)) {
513 instanceOp.emitOpError("must be lowered before LowerCoroutines");
514 return WalkResult::interrupt();
516 if (
auto defineOp = dyn_cast<CoroutineDefineOp>(op))
517 defineOps.push_back(defineOp);
518 return WalkResult::advance();
520 if (walkResult.wasInterrupted())
521 return signalPassFailure();
528 DenseMap<StringAttr, CoroutineLowering> lowerings;
529 for (
auto defineOp : defineOps) {
531 return signalPassFailure();
541 enum class Color { Unvisited, InProgress, Done };
542 DenseMap<StringAttr, Color> colors;
543 std::function<LogicalResult(StringAttr)> checkCycles =
544 [&](StringAttr name) -> LogicalResult {
545 auto it = lowerings.find(name);
546 if (it == lowerings.end())
548 if (colors.lookup(name) == Color::Done)
550 if (colors.lookup(name) == Color::InProgress)
551 return it->second.defineOp.emitOpError(
552 "recursive coroutines are not supported");
553 colors[name] = Color::InProgress;
554 auto result = success();
555 it->second.stateType.walk([&](CoroutineStateType type) {
556 if (failed(checkCycles(type.getCoroutine().getAttr())))
559 colors[name] = Color::Done;
562 for (
auto defineOp : defineOps)
563 if (failed(checkCycles(defineOp.getSymNameAttr())))
564 return signalPassFailure();
568 for (
auto defineOp : defineOps)
576 bool hasUnknownCoroutines =
false;
577 auto lookupLowering =
578 [&](FlatSymbolRefAttr coroutine) -> CoroutineLowering * {
579 auto it = lowerings.find(coroutine.getAttr());
580 if (it != lowerings.end())
582 hasUnknownCoroutines =
true;
583 mlir::emitError(module.getLoc())
584 <<
"coroutine type references unknown coroutine " << coroutine;
587 AttrTypeReplacer replacer;
588 replacer.addReplacement([&](CoroutineStateType type) -> std::optional<Type> {
589 if (
auto *lowering = lookupLowering(type.getCoroutine()))
590 return lowering->stateType;
593 replacer.addReplacement([&](CoroutinePCType type) -> std::optional<Type> {
594 if (
auto *lowering = lookupLowering(type.getCoroutine()))
595 return lowering->pcType;
602 SmallVector<Operation *> opsToLower;
603 module->walk([&](Operation *op) {
604 if (isa<CoroutineCallOp, CoroutineStartPCOp, CoroutineUndefinedStateOp,
605 CoroutinePCIsReturnOp, CoroutinePCIsHaltOp>(op))
606 opsToLower.push_back(op);
612 auto getPCWidth = [&](Value pc) -> std::optional<unsigned> {
613 if (
auto intType = dyn_cast<IntegerType>(pc.getType()))
614 return intType.getWidth();
615 if (
auto intType = dyn_cast<IntegerType>(replacer.replace(pc.getType())))
616 return intType.getWidth();
620 for (
auto *op : opsToLower) {
621 OpBuilder builder(op);
622 TypeSwitch<Operation *>(op)
623 .Case<CoroutineCallOp>([&](CoroutineCallOp op) {
627 llvm::map_to_vector(op.getResultTypes(), [&](Type type) {
628 return replacer.replace(type);
631 func::CallOp::create(builder, op.getLoc(), op.getCalleeAttr(),
632 resultTypes, op.getOperands());
633 op->replaceAllUsesWith(callOp);
636 .Case<CoroutineStartPCOp>([&](CoroutineStartPCOp op) {
637 auto pcType = dyn_cast<IntegerType>(replacer.replace(op.getType()));
641 APInt(pcType.getWidth(), 0));
642 op->replaceAllUsesWith(ValueRange{value});
645 .Case<CoroutineUndefinedStateOp>([&](CoroutineUndefinedStateOp op) {
648 auto stateType = replacer.replace(op.getType());
649 if (stateType == op.getType())
652 ub::PoisonOp::create(builder, op.getLoc(), stateType,
653 ub::PoisonAttr::get(builder.getContext()));
654 op->replaceAllUsesWith(ValueRange{value});
657 .Case<CoroutinePCIsReturnOp, CoroutinePCIsHaltOp>([&](
auto op) {
660 Value pc = op->getOperand(0);
661 auto pcWidth = getPCWidth(pc);
664 auto sentinel = APInt::getAllOnes(*pcWidth);
665 if (isa<CoroutinePCIsReturnOp>(op))
669 Value cmpValue = comb::ICmpOp::create(
670 builder, op.getLoc(), comb::ICmpPredicate::eq, pc, constValue);
671 op->replaceAllUsesWith(ValueRange{cmpValue});
679 replacer.recursivelyReplaceElementsIn(module,
true,
682 if (hasUnknownCoroutines)
683 return signalPassFailure();
static std::unique_ptr< Context > context
static CoroutineLowering analyzeDefinition(CoroutineDefineOp defineOp)
Determine the resume blocks of a coroutine and derive its concrete PC and state types.
static void lowerDefinition(CoroutineLowering &lowering)
Replace a coroutine definition with a state machine function.
static LogicalResult captureValue(Value value, Region &region, ArrayRef< Block * > captureBlocks, Liveness &liveness, DominanceInfo &dominance)
Capture a single value as a trailing block argument of each of the given resume blocks and rewrite al...
static SmallVector< Block * > collectResumeBlocks(Region &region)
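Collect the resume blocks of a coroutine body, i.e. the blocks that are the destination of a `CoroutineYieldOp` terminator, in the order in which they appear in the region.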
static LogicalResult captureValuesAcrossSuspension(CoroutineDefineOp defineOp)
Rewrite the body of a coroutine such that every resume block captures the values that are live across...
static void rematerializeConstant(Value value, Region &region, ArrayRef< Block * > captureBlocks, Liveness &liveness)
Clone a constant into the using blocks that are reachable from a capturing resume block,...
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
create(elements, Type result_type=None)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.