CIRCT 23.0.0git
LowerCoroutines.cpp
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower `arc.coroutine.define` ops into state machine functions. Coroutines
// are functions that can suspend execution at `arc.coroutine.yield` ops and
// be resumed at a later point. This pass converts each coroutine definition
// into a plain `func.func` that dispatches on an explicit integer program
// counter (PC) argument, and that persists all values live across suspension
// points in an explicit state argument.
//
// The lowering proceeds in two steps. The first step lowers each coroutine
// definition independently of all others:
//
// - Rewrite the body such that every resume block receives the values that
//   are live across its suspension points as trailing block arguments, which
//   makes the state that has to be persisted explicit. Values that do not
//   cross a suspension point are left untouched. Constants that do are
//   cloned into the resumed blocks rather than persisted, since
//   rematerializing them is cheaper.
//
// - Assign a PC value to each resume block and derive the concrete PC and
//   state types. The PC type is an integer just wide enough to encode the
//   start PC (0), one resume PC per resume block (1 to N), and the return and
//   halt sentinels (the two largest values). The state type is a union with
//   one struct variant per resume block that has values to persist, or an
//   empty struct if no resume block has any.
//
// - Replace the definition with a `func.func` that takes the state and PC as
//   leading arguments, dispatches on the PC to either the original entry
//   block or one of the resume blocks (unpacking the corresponding state
//   variant), and returns the new state, resume PC, and yielded values at
//   each suspension point.
//
// At this point, the state and PC types of *other* coroutines may still occur
// as opaque `!arc.coroutine_state` and `!arc.coroutine_pc` types within the
// lowered functions, for example on calls to nested coroutines. The second
// step performs a single global sweep over the module that concretizes all
// such occurrences:
//
// - `arc.coroutine.call` ops become plain `func.call` ops, and the auxiliary
//   coroutine ops become integer constants, comparisons, and poison values.
//
// - All remaining occurrences of the opaque types -- in block arguments,
//   results of unrelated ops, function signatures, and nested within
//   aggregate types -- are replaced by the concrete types computed in step
//   one. Since a coroutine's persistent state may contain the state of the
//   coroutines it calls, this replacement is recursive. Cyclic state
//   containment, i.e. recursive coroutines, is detected and rejected.
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Arc/ArcOps.h"
#include "circt/Dialect/Arc/ArcPasses.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/AttrTypeSubElements.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"

namespace circt {
namespace arc {
#define GEN_PASS_DEF_LOWERCOROUTINESPASS
#include "circt/Dialect/Arc/ArcPasses.h.inc"
} // namespace arc
} // namespace circt

using namespace circt;
using namespace arc;
using namespace mlir;

//===----------------------------------------------------------------------===//
// Suspension Value Capture
//===----------------------------------------------------------------------===//
//
// The following functions rewrite the body of a coroutine such that every
// resume block receives the values that are live across the suspension points
// it resumes from as trailing block arguments. After the lowering, control
// re-enters a resume block directly from the dispatch logic, so the
// definitions of these values no longer dominate their uses. Capturing them as
// block arguments makes the state that has to be persisted explicit, and the
// lowering picks the arguments up as the contents of the resume block's state
// variant.
//
// Values that do not cross a suspension point are left untouched and keep
// using their original definition through dominance. Where a control flow
// path carrying a captured value rejoins a path carrying the original
// definition, the join block receives a merging block argument as well.
// Constants are not captured; they are cloned into the using blocks that are
// reachable from a resume block instead, since they are cheaper to
// rematerialize on re-entry than to persist across suspension points.

/// Collect the resume blocks of a coroutine body, i.e. the blocks targeted by
/// yield ops, in region order to make the lowering deterministic.
static SmallVector<Block *> collectResumeBlocks(Region &region) {
  SmallPtrSet<Block *, 8> resumeBlockSet;
  for (auto &block : region)
    if (auto yieldOp = dyn_cast<CoroutineYieldOp>(block.getTerminator()))
      resumeBlockSet.insert(yieldOp.getDest());
  SmallVector<Block *> resumeBlocks;
  for (auto &block : region)
    if (resumeBlockSet.contains(&block))
      resumeBlocks.push_back(&block);
  return resumeBlocks;
}

/// Capture a single value as a trailing block argument of each of the given
/// resume blocks and rewrite all affected uses of the value.
static LogicalResult captureValue(Value value, Region &region,
                                  ArrayRef<Block *> captureBlocks,
                                  Liveness &liveness,
                                  DominanceInfo &dominance) {
  auto *defBlock = value.getParentBlock();

  // Determine the blocks that receive an argument for the value. The resume
  // blocks capture it directly. In addition, wherever a path carrying the
  // captured value rejoins a path carrying the original definition, the join
  // block needs a merging argument. These join blocks are the iterated
  // dominance frontier of the capture blocks plus the original definition,
  // pruned to the blocks where the value is live-in.
  auto &domTree = dominance.getDomTree(&region);
  llvm::IDFCalculatorBase<Block, false> idfCalculator(domTree);

  SmallPtrSet<Block *, 8> definingBlocks(captureBlocks.begin(),
                                         captureBlocks.end());
  definingBlocks.insert(defBlock);
  idfCalculator.setDefiningBlocks(definingBlocks);

  SmallPtrSet<Block *, 16> liveInBlocks;
  for (auto &block : region)
    if (liveness.getLiveness(&block)->isLiveIn(value))
      liveInBlocks.insert(&block);
  idfCalculator.setLiveInBlocks(liveInBlocks);

  SmallVector<Block *> mergeBlocks;
  idfCalculator.calculate(mergeBlocks);

  SmallPtrSet<Block *, 16> argBlocks(mergeBlocks.begin(), mergeBlocks.end());
  argBlocks.insert(captureBlocks.begin(), captureBlocks.end());

  // Since the value is an SSA value, its defining block dominates its entire
  // live range, and with it all argument blocks and their predecessors. Walk
  // the dominator tree from there, tracking the reaching definition of the
  // value, which is the original value until an argument block redefines it.
  // Rewrite all uses to the reaching definition of their block, and pass the
  // definition at the end of each block into any argument block successors.
  struct WorklistItem {
    DominanceInfoNode *domNode;
    Value reachingDef;
  };
  SmallVector<WorklistItem> worklist;
  worklist.push_back({domTree.getNode(defBlock), value});

  while (!worklist.empty()) {
    auto item = worklist.pop_back_val();
    auto *block = item.domNode->getBlock();

    if (argBlocks.contains(block))
      item.reachingDef = block->addArgument(value.getType(), value.getLoc());

    // Rewrite the uses in this block, including in nested regions.
    if (item.reachingDef != value)
      block->walk([&](Operation *nestedOp) {
        nestedOp->replaceUsesOfWith(value, item.reachingDef);
      });

    // Append the reaching definition to the successor operands of every edge
    // into an argument block.
    auto *terminator = block->getTerminator();
    auto branchOp = dyn_cast<BranchOpInterface>(terminator);
    for (auto &blockOperand : terminator->getBlockOperands()) {
      if (!argBlocks.contains(blockOperand.get()))
        continue;
      if (!branchOp)
        return terminator->emitOpError()
               << "does not implement `BranchOpInterface`; cannot pass value "
                  "into successor block";
      branchOp.getSuccessorOperands(blockOperand.getOperandNumber())
          .append(item.reachingDef);
    }

    for (auto *child : item.domNode->children())
      worklist.push_back({child, item.reachingDef});
  }

  return success();
}

/// Clone a constant into the using blocks that are reachable from a capturing
/// resume block, where the original definition no longer dominates its uses
/// after the lowering. All other uses keep using the original.
static void rematerializeConstant(Value value, Region &region,
                                  ArrayRef<Block *> captureBlocks,
                                  Liveness &liveness) {
  // Collect the blocks reachable from a capturing resume block in which the
  // value is live-in.
  DenseSet<Block *> resumedBlocks(captureBlocks.begin(), captureBlocks.end());
  SmallVector<Block *> worklist(captureBlocks.begin(), captureBlocks.end());
  while (!worklist.empty())
    for (auto *successor : worklist.pop_back_val()->getSuccessors())
      if (liveness.getLiveIn(successor).contains(value) &&
          resumedBlocks.insert(successor).second)
        worklist.push_back(successor);

  // Redirect the uses in these blocks to a per-block clone of the constant.
  auto *defOp = value.getDefiningOp();
  DenseMap<Block *, Value> clones;
  for (auto &use : llvm::make_early_inc_range(value.getUses())) {
    auto *block = region.findAncestorBlockInRegion(*use.getOwner()->getBlock());
    if (!resumedBlocks.contains(block))
      continue;
    auto &clone = clones[block];
    if (!clone) {
      OpBuilder builder(block, block->begin());
      clone = builder.clone(*defOp)->getResult(0);
    }
    use.set(clone);
  }
  if (defOp->use_empty())
    defOp->erase();
}

/// Rewrite the body of a coroutine such that every resume block captures the
/// values that are live across the suspension points it resumes from as
/// trailing block arguments.
static LogicalResult captureValuesAcrossSuspension(CoroutineDefineOp defineOp) {
  Region &region = defineOp.getBody();
  if (region.hasOneBlock())
    return success();

  auto resumeBlocks = collectResumeBlocks(region);
  if (resumeBlocks.empty())
    return success();

  // Compute the liveness and dominance of the original body once. The
  // per-value rewrites below only ever change the liveness of the value
  // currently being processed, never that of the other collected values, and
  // they add no blocks or control flow edges, so both analyses remain valid
  // throughout the loop.
  Liveness liveness(defineOp);
  DominanceInfo dominance(defineOp);

  // Collect the candidate values up front, since the rewriting below adds new
  // block arguments and constants which need no further treatment.
  SmallVector<Value> values;
  for (auto &block : region) {
    llvm::append_range(values, block.getArguments());
    for (auto &op : block)
      llvm::append_range(values, op.getResults());
  }

  for (auto value : values) {
    // Collect the resume blocks where the value is live-in. Values that are
    // not live across any suspension point still dominate their uses after
    // the lowering and need no rewriting at all.
    SmallVector<Block *> captureBlocks;
    for (auto *block : resumeBlocks)
      if (liveness.getLiveIn(block).contains(value))
        captureBlocks.push_back(block);
    if (captureBlocks.empty())
      continue;

    // Rematerialize constants in the resumed blocks instead of capturing
    // them; capture all other values as trailing resume block arguments.
    auto *defOp = value.getDefiningOp();
    if (defOp && defOp->hasTrait<OpTrait::ConstantLike>()) {
      rematerializeConstant(value, region, captureBlocks, liveness);
      continue;
    }
    if (failed(captureValue(value, region, captureBlocks, liveness, dominance)))
      return failure();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Coroutine Analysis
//===----------------------------------------------------------------------===//

namespace {
/// Per-coroutine lowering info: the resume blocks and the concrete PC and
/// state types derived from the body after the values live across suspension
/// points have been captured.
struct CoroutineLowering {
  /// The original coroutine definition.
  CoroutineDefineOp defineOp;
  /// The blocks targeted by yield ops, in region order. The PC value of a
  /// resume block is its index in this list plus one.
  SmallVector<Block *> resumeBlocks;
  /// The union field index holding the persisted state of each resume block,
  /// or none if the resume block persists no state.
  SmallVector<std::optional<unsigned>> variantIndices;
  /// The concrete PC type.
  IntegerType pcType;
  /// The concrete state type. This is a union with one struct variant per
  /// state-persisting resume block, or an empty struct if no resume block
  /// persists any state. It may still contain the opaque state and PC types
  /// of other coroutines.
  Type stateType;

  /// Returns the sentinel PC value indicating that the coroutine returned.
  uint64_t getReturnPC() { return getHaltPC() - 1; }
  /// Returns the sentinel PC value indicating that the coroutine halted.
  uint64_t getHaltPC() {
    return APInt::getAllOnes(pcType.getWidth()).getZExtValue();
  }
  /// Returns the union variant persisting the state of the resume block with
  /// the given index, or none if the resume block persists no state.
  std::optional<hw::UnionType::FieldInfo> getVariant(unsigned resumeIndex) {
    if (auto fieldIndex = variantIndices[resumeIndex])
      return cast<hw::UnionType>(stateType).getElements()[*fieldIndex];
    return std::nullopt;
  }
};
} // namespace

/// Determine the resume blocks of a coroutine and derive its concrete PC and
/// state types. To be called after the values live across suspension points
/// have been captured, such that the trailing block arguments of each resume
/// block are exactly the values that have to be persisted.
static CoroutineLowering analyzeDefinition(CoroutineDefineOp defineOp) {
  auto *context = defineOp.getContext();
  CoroutineLowering lowering;
  lowering.defineOp = defineOp;

  lowering.resumeBlocks = collectResumeBlocks(defineOp.getBody());

  // The PC type must be wide enough to encode the start PC (0), one resume PC
  // per resume block (1 to N), and the return and halt sentinels (the two
  // largest values).
  unsigned pcWidth = llvm::Log2_64_Ceil(lowering.resumeBlocks.size() + 3);
  lowering.pcType = IntegerType::get(context, pcWidth);

  // Build the state type as a union with one struct variant per resume block
  // that has values to persist, or an empty struct if there are none. The
  // variants are named after the resume PC (`r1`, `r2`, ...), and the struct
  // fields after the position of the resume block's trailing arguments (`f0`,
  // `f1`, ...).
  unsigned numCoroutineArgs = defineOp.getArgumentTypes().size();
  SmallVector<hw::UnionType::FieldInfo> variants;
  for (auto [index, block] : llvm::enumerate(lowering.resumeBlocks)) {
    auto persistedArgs = block->getArguments().drop_front(numCoroutineArgs);
    if (persistedArgs.empty()) {
      lowering.variantIndices.push_back(std::nullopt);
      continue;
    }
    SmallVector<hw::StructType::FieldInfo> fields;
    for (auto [fieldIndex, arg] : llvm::enumerate(persistedArgs))
      fields.push_back(
          {StringAttr::get(context, "f" + Twine(fieldIndex)), arg.getType()});
    lowering.variantIndices.push_back(variants.size());
    variants.push_back({StringAttr::get(context, "r" + Twine(index + 1)),
                        hw::StructType::get(context, fields), 0});
  }
  if (variants.empty())
    lowering.stateType = hw::StructType::get(context, {});
  else
    lowering.stateType = hw::UnionType::get(context, variants);

  return lowering;
}

//===----------------------------------------------------------------------===//
// Coroutine Lowering
//===----------------------------------------------------------------------===//

/// Replace a coroutine definition with a state machine function. The function
/// takes the persistent state and PC as leading arguments and dispatches on
/// the PC to either the original entry block or, through a trampoline block
/// that unpacks the corresponding state variant, to one of the resume blocks.
/// The coroutine terminators become function returns that produce the new
/// state, resume PC, and yielded values.
static void lowerDefinition(CoroutineLowering &lowering) {
  auto defineOp = lowering.defineOp;
  auto loc = defineOp.getLoc();
  auto *context = defineOp.getContext();
  unsigned pcWidth = lowering.pcType.getWidth();

  // Create the replacement function. The signature wraps the coroutine's
  // function type with the state and PC as leading arguments and results.
  SmallVector<Type> inputTypes{lowering.stateType, lowering.pcType};
  llvm::append_range(inputTypes, defineOp.getArgumentTypes());
  SmallVector<Type> resultTypes{lowering.stateType, lowering.pcType};
  llvm::append_range(resultTypes, defineOp.getResultTypes());
  OpBuilder builder(defineOp);
  auto funcOp =
      func::FuncOp::create(builder, loc, defineOp.getSymName(),
                           builder.getFunctionType(inputTypes, resultTypes));

  // Move the body over and create the dispatch block in front of it.
  funcOp.getBody().getBlocks().splice(funcOp.getBody().end(),
                                      defineOp.getBody().getBlocks());
  auto *entryBlock = &funcOp.getBody().front();
  auto *dispatchBlock =
      builder.createBlock(&funcOp.getBody(), funcOp.getBody().begin());
  Value stateArg = dispatchBlock->addArgument(lowering.stateType, loc);
  Value pcArg = dispatchBlock->addArgument(lowering.pcType, loc);
  SmallVector<Value> callerArgs;
  for (auto arg : entryBlock->getArguments())
    callerArgs.push_back(
        dispatchBlock->addArgument(arg.getType(), arg.getLoc()));

  // Create a trampoline block for each resume block that unpacks the
  // persisted values from the corresponding state variant and passes them to
  // the resume block, alongside the fresh caller-supplied arguments.
  SmallVector<APInt> caseValues;
  SmallVector<Block *> caseBlocks;
  for (auto [index, resumeBlock] : llvm::enumerate(lowering.resumeBlocks)) {
    auto *trampolineBlock = builder.createBlock(resumeBlock);
    SmallVector<Value> operands = callerArgs;
    if (auto variant = lowering.getVariant(index)) {
      Value variantValue =
          hw::UnionExtractOp::create(builder, loc, stateArg, variant->name);
      auto explodeOp = hw::StructExplodeOp::create(builder, loc, variantValue);
      llvm::append_range(operands, explodeOp.getResults());
    }
    cf::BranchOp::create(builder, loc, resumeBlock, operands);
    caseValues.push_back(APInt(pcWidth, index + 1));
    caseBlocks.push_back(trampolineBlock);
  }

  // Dispatch on the PC. The start PC enters the original entry block;
  // passing a return or halt PC is undefined behavior, so those simply fall
  // into the default case alongside the start PC.
  builder.setInsertionPointToEnd(dispatchBlock);
  if (lowering.resumeBlocks.empty()) {
    cf::BranchOp::create(builder, loc, entryBlock, callerArgs);
  } else {
    SmallVector<ValueRange> caseOperands(caseBlocks.size());
    cf::SwitchOp::create(builder, loc, pcArg, entryBlock, callerArgs,
                         caseValues, caseBlocks, caseOperands);
  }

  // Replace the coroutine terminators with function returns producing the new
  // state, resume PC, and yielded values.
  DenseMap<Block *, unsigned> resumePCs;
  for (auto [index, block] : llvm::enumerate(lowering.resumeBlocks))
    resumePCs[block] = index + 1;

  auto lowerTerminator = [&](Operation *op, Value state, uint64_t pc,
                             ValueRange yieldOperands) {
    OpBuilder builder(op);
    if (!state)
      state = ub::PoisonOp::create(builder, op->getLoc(), lowering.stateType,
                                   ub::PoisonAttr::get(context));
    Value pcValue =
        hw::ConstantOp::create(builder, op->getLoc(), APInt(pcWidth, pc));
    SmallVector<Value> operands{state, pcValue};
    llvm::append_range(operands, yieldOperands);
    func::ReturnOp::create(builder, op->getLoc(), operands);
    op->erase();
  };

  for (auto &block : funcOp.getBody()) {
    TypeSwitch<Operation *>(block.getTerminator())
        .Case<CoroutineYieldOp>([&](CoroutineYieldOp op) {
          // Pack the values to persist into the state variant corresponding
          // to the destination block. Resume blocks without persisted state
          // have no variant and return a poison state.
          unsigned pc = resumePCs.lookup(op.getDest());
          Value state;
          if (auto variant = lowering.getVariant(pc - 1)) {
            OpBuilder builder(op);
            Value variantValue = hw::StructCreateOp::create(
                builder, op.getLoc(), variant->type, op.getDestOperands());
            state = hw::UnionCreateOp::create(builder, op.getLoc(),
                                              lowering.stateType, variant->name,
                                              variantValue);
          }
          lowerTerminator(op, state, pc, op.getYieldOperands());
        })
        .Case<CoroutineReturnOp>([&](CoroutineReturnOp op) {
          lowerTerminator(op, Value{}, lowering.getReturnPC(),
                          op.getYieldOperands());
        })
        .Case<CoroutineHaltOp>([&](CoroutineHaltOp op) {
          lowerTerminator(op, Value{}, lowering.getHaltPC(),
                          op.getYieldOperands());
        });
  }

  defineOp.erase();
}

//===----------------------------------------------------------------------===//
// Pass Implementation
//===----------------------------------------------------------------------===//

namespace {
struct LowerCoroutinesPass
    : public arc::impl::LowerCoroutinesPassBase<LowerCoroutinesPass> {
  void runOnOperation() override;
};
} // namespace

void LowerCoroutinesPass::runOnOperation() {
  auto module = getOperation();

  // Collect the coroutine definitions to lower, and reject any leftover
  // coroutine instances in the same pass over the module. Instances are
  // expected to be lowered into explicit storage and calls beforehand;
  // rejecting them here avoids breaking their symbol references.
  SmallVector<CoroutineDefineOp> defineOps;
  auto walkResult = module->walk([&](Operation *op) {
    if (auto instanceOp = dyn_cast<CoroutineInstanceOp>(op)) {
      instanceOp.emitOpError("must be lowered before LowerCoroutines");
      return WalkResult::interrupt();
    }
    if (auto defineOp = dyn_cast<CoroutineDefineOp>(op))
      defineOps.push_back(defineOp);
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return signalPassFailure();

  // Step 1: Lower each coroutine definition independently. Capture the values
  // live across suspension points as trailing resume block arguments and
  // derive the concrete PC and state types. The state types may still contain
  // opaque types of other coroutines, which the global sweep below
  // concretizes.
  DenseMap<StringAttr, CoroutineLowering> lowerings;
  for (auto defineOp : defineOps) {
    if (failed(captureValuesAcrossSuspension(defineOp)))
      return signalPassFailure();
    lowerings.insert({defineOp.getSymNameAttr(), analyzeDefinition(defineOp)});
  }

  // Reject recursive coroutines. A coroutine whose persistent state
  // transitively contains its own state would require unbounded storage.
  // Detect cycles in the state containment graph with a depth-first
  // traversal. This must run before step 2, whose recursive type replacement
  // would not terminate on cycles, and before the definitions are erased, so
  // the error can point at the offending op.
  enum class Color { Unvisited, InProgress, Done };
  DenseMap<StringAttr, Color> colors;
  std::function<LogicalResult(StringAttr)> checkCycles =
      [&](StringAttr name) -> LogicalResult {
    auto it = lowerings.find(name);
    if (it == lowerings.end())
      return success();
    if (colors.lookup(name) == Color::Done)
      return success();
    if (colors.lookup(name) == Color::InProgress)
      return it->second.defineOp.emitOpError(
          "recursive coroutines are not supported");
    colors[name] = Color::InProgress;
    auto result = success();
    it->second.stateType.walk([&](CoroutineStateType type) {
      if (failed(checkCycles(type.getCoroutine().getAttr())))
        result = failure();
    });
    colors[name] = Color::Done;
    return result;
  };
  for (auto defineOp : defineOps)
    if (failed(checkCycles(defineOp.getSymNameAttr())))
      return signalPassFailure();

  // Replace each coroutine definition with a state machine function,
  // completing step 1.
  for (auto defineOp : defineOps)
    lowerDefinition(lowerings.find(defineOp.getSymNameAttr())->second);

  // Step 2: Concretize all occurrences of the opaque coroutine state and PC
  // types throughout the module. The replacer recurses into the replacement
  // types, such that a coroutine state containing the state of a nested
  // coroutine gets fully concretized; the cycle check above guarantees that
  // this recursion terminates.
  bool hasUnknownCoroutines = false;
  auto lookupLowering =
      [&](FlatSymbolRefAttr coroutine) -> CoroutineLowering * {
    auto it = lowerings.find(coroutine.getAttr());
    if (it != lowerings.end())
      return &it->second;
    hasUnknownCoroutines = true;
    mlir::emitError(module.getLoc())
        << "coroutine type references unknown coroutine " << coroutine;
    return nullptr;
  };
  AttrTypeReplacer replacer;
  replacer.addReplacement([&](CoroutineStateType type) -> std::optional<Type> {
    if (auto *lowering = lookupLowering(type.getCoroutine()))
      return lowering->stateType;
    return std::nullopt;
  });
  replacer.addReplacement([&](CoroutinePCType type) -> std::optional<Type> {
    if (auto *lowering = lookupLowering(type.getCoroutine()))
      return lowering->pcType;
    return std::nullopt;
  });

  // Rewrite the ops that consume or produce values of the opaque types. This
  // must happen before the type sweep below, since some of these ops identify
  // their coroutine solely through the symbol carried in their types.
  SmallVector<Operation *> opsToLower;
  module->walk([&](Operation *op) {
    if (isa<CoroutineCallOp, CoroutineStartPCOp, CoroutineUndefinedStateOp,
            CoroutinePCIsReturnOp, CoroutinePCIsHaltOp>(op))
      opsToLower.push_back(op);
  });

  // Determine the PC width for a sentinel check. The PC operand either still
  // has the opaque PC type, or has already been concretized to an integer if
  // its producer was rewritten earlier in the loop below.
  auto getPCWidth = [&](Value pc) -> std::optional<unsigned> {
    if (auto intType = dyn_cast<IntegerType>(pc.getType()))
      return intType.getWidth();
    if (auto intType = dyn_cast<IntegerType>(replacer.replace(pc.getType())))
      return intType.getWidth();
    return std::nullopt;
  };

  for (auto *op : opsToLower) {
    OpBuilder builder(op);
    TypeSwitch<Operation *>(op)
        .Case<CoroutineCallOp>([&](CoroutineCallOp op) {
          // The lowered function signature matches the coroutine call ABI
          // exactly, so the call maps to a plain function call.
          auto resultTypes =
              llvm::map_to_vector(op.getResultTypes(), [&](Type type) {
                return replacer.replace(type);
              });
          auto callOp =
              func::CallOp::create(builder, op.getLoc(), op.getCalleeAttr(),
                                   resultTypes, op.getOperands());
          op->replaceAllUsesWith(callOp);
          op->erase();
        })
        .Case<CoroutineStartPCOp>([&](CoroutineStartPCOp op) {
          auto pcType = dyn_cast<IntegerType>(replacer.replace(op.getType()));
          if (!pcType)
            return;
          Value value = hw::ConstantOp::create(builder, op.getLoc(),
                                               APInt(pcType.getWidth(), 0));
          op->replaceAllUsesWith(ValueRange{value});
          op->erase();
        })
        .Case<CoroutineUndefinedStateOp>([&](CoroutineUndefinedStateOp op) {
          // The state passed on the very first entry into a coroutine is
          // never read, so any value will do.
          auto stateType = replacer.replace(op.getType());
          if (stateType == op.getType())
            return;
          Value value =
              ub::PoisonOp::create(builder, op.getLoc(), stateType,
                                   ub::PoisonAttr::get(builder.getContext()));
          op->replaceAllUsesWith(ValueRange{value});
          op->erase();
        })
        .Case<CoroutinePCIsReturnOp, CoroutinePCIsHaltOp>([&](auto op) {
          // Use the raw operand instead of the typed ODS getter; the PC may
          // already have been concretized to an integer (see `getPCWidth`).
          Value pc = op->getOperand(0);
          auto pcWidth = getPCWidth(pc);
          if (!pcWidth)
            return;
          auto sentinel = APInt::getAllOnes(*pcWidth);
          if (isa<CoroutinePCIsReturnOp>(op))
            sentinel -= 1;
          Value constValue =
              hw::ConstantOp::create(builder, op.getLoc(), sentinel);
          Value cmpValue = comb::ICmpOp::create(
              builder, op.getLoc(), comb::ICmpPredicate::eq, pc, constValue);
          op->replaceAllUsesWith(ValueRange{cmpValue});
          op->erase();
        });
  }

  // Sweep over the entire module and replace all remaining occurrences of the
  // opaque types. This covers block arguments, results of unrelated ops,
  // function signatures, and types nested within aggregates.
  replacer.recursivelyReplaceElementsIn(module, /*replaceAttrs=*/true,
                                        /*replaceLocs=*/false,
                                        /*replaceTypes=*/true);
  if (hasUnknownCoroutines)
    return signalPassFailure();
}
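`captureValue` decides which join blocks need a merging argument. It computes the iterated dominance frontier of the capture blocks plus the defining block, pruned to the blocks where the value is live-in. The standalone sketch below illustrates that computation; it is not the pass's code. It assumes a hypothetical `Graph` of integer-indexed successor lists with block 0 as the entry. It also swaps LLVM's `IDFCalculatorBase` and MLIR's dominator tree for naive set-based dominators, which is adequate for small graphs but quadratic in general. The names `dominatorSets`, `dominanceFrontiers`, and `prunedIDF` are illustrative.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical CFG encoding: succs[b] lists the successors of block b, and
// block 0 is the entry block.
using Graph = std::vector<std::vector<int>>;

// Dominator sets by naive iterative data flow:
//   dom(entry) = {entry}
//   dom(b)     = {b} + intersection of dom(p) over all predecessors p
// Assumes every block is reachable from the entry.
std::vector<std::set<int>> dominatorSets(const Graph &succs) {
  assert(!succs.empty() && "graph needs an entry block");
  int n = static_cast<int>(succs.size());
  std::vector<std::vector<int>> preds(n);
  for (int a = 0; a < n; ++a)
    for (int b : succs[a])
      preds[b].push_back(a);
  std::set<int> all;
  for (int b = 0; b < n; ++b)
    all.insert(b);
  std::vector<std::set<int>> dom(n, all);
  dom[0] = {0};
  for (bool changed = true; changed;) {
    changed = false;
    for (int b = 1; b < n; ++b) {
      std::set<int> next = all;
      for (int p : preds[b]) {
        std::set<int> meet;
        for (int x : next)
          if (dom[p].count(x))
            meet.insert(x);
        next.swap(meet);
      }
      next.insert(b);
      if (next != dom[b]) {
        dom[b] = next;
        changed = true;
      }
    }
  }
  return dom;
}

// y is in DF(x) iff x dominates a predecessor of y but does not strictly
// dominate y.
std::vector<std::set<int>> dominanceFrontiers(const Graph &succs) {
  auto dom = dominatorSets(succs);
  std::vector<std::set<int>> df(succs.size());
  for (int a = 0; a < static_cast<int>(succs.size()); ++a)
    for (int y : succs[a])
      for (int x : dom[a])
        if (x == y || !dom[y].count(x))
          df[x].insert(y);
  return df;
}

// Iterated dominance frontier of the defining blocks (capture blocks plus the
// original definition), pruned to blocks where the value is live-in. These
// are the join blocks that need a merging block argument.
std::set<int> prunedIDF(const Graph &succs, const std::set<int> &defBlocks,
                        const std::set<int> &liveIn) {
  auto df = dominanceFrontiers(succs);
  std::set<int> merges;
  std::vector<int> worklist(defBlocks.begin(), defBlocks.end());
  while (!worklist.empty()) {
    int x = worklist.back();
    worklist.pop_back();
    for (int y : df[x])
      if (liveIn.count(y) && merges.insert(y).second)
        worklist.push_back(y); // a merge block is itself a new definition
  }
  return merges;
}
```

Consider the diamond `{{1, 2}, {3}, {4}, {4}, {}}`, where block 1 yields to resume block 3 and both blocks 2 and 3 branch to block 4. There the only merge block is 4. If the resume block instead branches back to a loop header, the header becomes the merge block. If the value is not live-in at the join, no merging argument is needed at all.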
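`rematerializeConstant` clones a constant only into blocks that are reachable from a capturing resume block through blocks where the constant is live-in; every other use keeps the original. The following minimal sketch of that reachability walk again assumes a hypothetical integer-indexed `Graph`. Precomputed live-in sets stand in for MLIR's `Liveness`, and `blocksNeedingClone` is an illustrative name.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical CFG encoding: succs[b] lists the successors of block b.
using Graph = std::vector<std::vector<int>>;

// Blocks that receive their own clone of a constant: the capturing resume
// blocks plus every block reachable from them along edges into blocks where
// the constant is live-in. Uses in all other blocks keep the original.
std::set<int> blocksNeedingClone(const Graph &succs,
                                 const std::vector<int> &captureBlocks,
                                 const std::set<int> &liveIn) {
  std::set<int> resumed(captureBlocks.begin(), captureBlocks.end());
  std::vector<int> worklist(captureBlocks.begin(), captureBlocks.end());
  while (!worklist.empty()) {
    int block = worklist.back();
    worklist.pop_back();
    assert(block >= 0 && block < static_cast<int>(succs.size()));
    for (int succ : succs[block])
      if (liveIn.count(succ) && resumed.insert(succ).second)
        worklist.push_back(succ);
  }
  return resumed;
}
```

On the same diamond with resume block 3, clones land in blocks 3 and 4, while a use in block 2 keeps the original definition. The clone in block 4 also serves the path coming from block 2. This is sound because a constant yields the same value wherever it is materialized, so no merging argument is needed.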
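The PC layout chosen by `analyzeDefinition` can be checked with plain integer arithmetic. The same applies to the sentinels used by `getReturnPC`, `getHaltPC`, and the rewrite of `CoroutinePCIsReturnOp` and `CoroutinePCIsHaltOp`. This sketch uses hypothetical helper names and only handles widths below 64 bits, whereas the pass works with `APInt` and `llvm::Log2_64_Ceil`.

```cpp
#include <cassert>
#include <cstdint>

// Smallest width w with 2^w >= numResumeBlocks + 3, i.e. ceil(log2(N + 3)),
// matching llvm::Log2_64_Ceil(resumeBlocks.size() + 3) in the pass.
unsigned pcWidth(uint64_t numResumeBlocks) {
  uint64_t numPCs = numResumeBlocks + 3; // start + N resume PCs + 2 sentinels
  unsigned width = 0;
  while (width < 64 && (uint64_t{1} << width) < numPCs)
    ++width;
  return width;
}

// The all-ones PC signals "halted"; the value just below it, "returned".
uint64_t haltPC(unsigned width) {
  assert(width < 64 && "sketch only handles widths below 64");
  return (uint64_t{1} << width) - 1;
}
uint64_t returnPC(unsigned width) { return haltPC(width) - 1; }

// Resume block i (0-based, in region order) resumes at PC i + 1; PC 0 is the
// start PC.
uint64_t resumePC(unsigned resumeIndex) { return resumeIndex + 1; }

// The sentinel checks lower to an equality comparison with these constants.
bool isReturnPC(uint64_t pc, unsigned width) { return pc == returnPC(width); }
bool isHaltPC(uint64_t pc, unsigned width) { return pc == haltPC(width); }
```

With five resume blocks, all eight values of a 3-bit PC are used: start 0, resume PCs 1 to 5, return 6, and halt 7. A sixth resume block pushes the PC type to 4 bits.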
static std::unique_ptr< Context > context
static CoroutineLowering analyzeDefinition(CoroutineDefineOp defineOp)
Determine the resume blocks of a coroutine and derive its concrete PC and state types.
static void lowerDefinition(CoroutineLowering &lowering)
Replace a coroutine definition with a state machine function.
static LogicalResult captureValue(Value value, Region &region, ArrayRef< Block * > captureBlocks, Liveness &liveness, DominanceInfo &dominance)
Capture a single value as a trailing block argument of each of the given resume blocks and rewrite al...
static SmallVector< Block * > collectResumeBlocks(Region &region)
Collect the resume blocks of a coroutine body, i.e.
static LogicalResult captureValuesAcrossSuspension(CoroutineDefineOp defineOp)
Rewrite the body of a coroutine such that every resume block captures the values that are live across...
static void rematerializeConstant(Value value, Region &region, ArrayRef< Block * > captureBlocks, Liveness &liveness)
Clone a constant into the using blocks that are reachable from a capturing resume block,...
static StringAttr append(StringAttr base, const Twine &suffix)
Return a attribute with the specified suffix appended.
create(data_type, value)
Definition hw.py:433
create(elements, Type result_type=None)
Definition hw.py:544
Definition arc.py:1
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
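The following hand-written C++ approximates what `lowerDefinition` produces for a hypothetical coroutine. It is an analogy rather than generated output: a C++ `union` stands in for the `hw` union state, and a `switch` stands in for the `cf.switch` dispatch. The caller-supplied `limit` is passed again on every call, just as the trampolines forward fresh caller arguments. The assumed original body has four parts:

- an entry block that branches to a loop header with `i = 0`;
- a block that yields `i` and targets a resume block persisting `i`;
- the resume block itself, which branches back to the header with `i + 1`;
- a return once `i` reaches `limit`.

All names here are illustrative.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// One resume block (PC 1), so the PC needs ceil(log2(1 + 3)) = 2 bits:
// start 0, resume 1, return 2, halt 3.
constexpr uint8_t kStartPC = 0, kResumePC = 1, kReturnPC = 2, kHaltPC = 3;

// Variant "r1" of the state union: a struct holding the value persisted
// across the suspension (field "f0").
struct ResumeState1 {
  uint32_t f0;
};
union State {
  ResumeState1 r1;
};

struct StepResult {
  State state;
  uint8_t pc;
  uint32_t yielded;
};

// Analogue of the lowered function: (state, pc, args) -> (state, pc, yields).
StepResult counterStep(State state, uint8_t pc, uint32_t limit) {
  uint32_t i;
  switch (pc) {
  case kResumePC:
    // Trampoline: unpack variant r1, then run the resume block, which
    // branches back to the loop header with i + 1.
    i = state.r1.f0 + 1;
    break;
  default:
    // Start PC: the original entry block branches to the header with i = 0.
    // Return and halt PCs are undefined inputs and also land here.
    i = 0;
    break;
  }
  // Loop header: suspend while i < limit, otherwise return.
  if (i < limit) {
    State next{};
    next.r1.f0 = i; // pack the value live across the suspension
    return {next, kResumePC, i};
  }
  // The pass returns a poison state here; a zeroed state and a placeholder
  // yielded value stand in for it.
  return {State{}, kReturnPC, 0};
}

// Caller loop: start at PC 0 with an arbitrary state and feed the returned
// state and PC back in until a sentinel PC comes back.
std::vector<uint32_t> runToCompletion(uint32_t limit) {
  std::vector<uint32_t> yielded;
  State state{};
  uint8_t pc = kStartPC;
  while (true) {
    StepResult result = counterStep(state, pc, limit);
    if (result.pc == kReturnPC || result.pc == kHaltPC)
      break;
    assert(result.pc == kResumePC && "only one resume PC exists");
    yielded.push_back(result.yielded);
    state = result.state;
    pc = result.pc;
  }
  return yielded;
}
```

The driver loop plays the caller's role. It starts at PC 0 with a state that is never read, which corresponds to `CoroutineUndefinedStateOp` being lowered to poison. It then keeps threading the returned state and PC back in until it sees a sentinel.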
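The recursion check in `runOnOperation` is a three-color depth-first search over the state containment graph. The sketch below shows the same idea over a hypothetical map from coroutine names to the coroutines whose state types occur in their own state. Unlike the pass, which emits an error and keeps walking, this version stops at the first cycle. It returns the coroutine at which the back edge was found. `findRecursiveCoroutine` and `ContainmentGraph` are illustrative names.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical containment graph: each coroutine maps to the coroutines
// whose state types occur inside its own state type.
using ContainmentGraph = std::map<std::string, std::vector<std::string>>;

enum class Color { Unvisited, InProgress, Done };

// Returns a coroutine on a containment cycle, or std::nullopt if there is
// none. Names missing from the graph are skipped, mirroring how the pass
// leaves unknown coroutines to the later type replacement to report.
std::optional<std::string>
findRecursiveCoroutine(const ContainmentGraph &graph) {
  std::map<std::string, Color> colors; // operator[] defaults to Unvisited
  std::function<std::optional<std::string>(const std::string &)> visit =
      [&](const std::string &name) -> std::optional<std::string> {
    auto it = graph.find(name);
    if (it == graph.end())
      return std::nullopt;
    Color color = colors[name];
    if (color == Color::Done)
      return std::nullopt;
    if (color == Color::InProgress)
      return name; // back edge: name transitively contains itself
    colors[name] = Color::InProgress;
    for (const auto &inner : it->second)
      if (auto found = visit(inner))
        return found;
    colors[name] = Color::Done;
    return std::nullopt;
  };
  for (const auto &entry : graph)
    if (auto found = visit(entry.first))
      return found;
  return std::nullopt;
}
```

The `Done` color keeps shared sub-states from being misreported as cycles. Suppose `top` contains `a` and `b`, and `a` also contains `b`. The second visit of `b` then finds it finished rather than in progress.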