CIRCT 23.0.0git
Loading...
Searching...
No Matches
CoreToFSM.cpp
Go to the documentation of this file.
1//===CoreToFSM.cpp - Convert Core Dialects (HW + Seq + Comb) to FSM Dialect===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17#include "circt/Support/LLVM.h"
18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/IR/Block.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/BuiltinOps.h"
23#include "mlir/IR/Diagnostics.h"
24#include "mlir/IR/IRMapping.h"
25#include "mlir/IR/MLIRContext.h"
26#include "mlir/IR/Value.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Pass/PassManager.h"
29#include "mlir/Transforms/DialectConversion.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#include "mlir/Transforms/Passes.h"
32#include "llvm/ADT/DenseMap.h"
33#include "llvm/ADT/DenseSet.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/LogicalResult.h"
37#include <string>
38
39namespace circt {
40#define GEN_PASS_DEF_CONVERTCORETOFSM
41#include "circt/Conversion/Passes.h.inc"
42} // namespace circt
43
44using namespace mlir;
45using namespace circt;
46using namespace hw;
47using namespace comb;
48using namespace fsm;
49
50namespace {
51
52// Forward declaration for our helper function
53static void generateConcatenatedValues(
54 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
55 const llvm::SmallVector<unsigned> &shifts,
56 llvm::SetVector<size_t> &finalPossibleValues);
57
58/// Internal helper with visited set to detect cycles.
59static void addPossibleValuesImpl(llvm::SetVector<size_t> &possibleValues,
60 Value v, llvm::DenseSet<Value> &visited) {
61 // Detect cycles - if we've seen this value before, skip it.
62 if (!visited.insert(v).second)
63 return;
64
65 if (auto c = dyn_cast_or_null<hw::ConstantOp>(v.getDefiningOp())) {
66 possibleValues.insert(c.getValueAttr().getValue().getZExtValue());
67 return;
68 }
69 if (auto m = dyn_cast_or_null<MuxOp>(v.getDefiningOp())) {
70 addPossibleValuesImpl(possibleValues, m.getTrueValue(), visited);
71 addPossibleValuesImpl(possibleValues, m.getFalseValue(), visited);
72 return;
73 }
74
75 if (auto concatOp = dyn_cast_or_null<ConcatOp>(v.getDefiningOp())) {
76 llvm::SmallVector<llvm::SetVector<size_t>> allOperandValues;
77 llvm::SmallVector<unsigned> operandWidths;
78
79 for (Value operand : concatOp.getOperands()) {
80 llvm::SetVector<size_t> operandPossibleValues;
81 addPossibleValuesImpl(operandPossibleValues, operand, visited);
82
83 // It's crucial to handle the case where a sub-computation is too complex.
84 // If we can't determine specific values for an operand, we must
85 // pessimistically assume it can be any value its bitwidth allows.
86 auto opType = dyn_cast<IntegerType>(operand.getType());
87 // comb.concat only accepts signless integer operands by definition in
88 // CIRCT's type system, so this assertion should always hold for valid IR.
89 assert(opType && "comb.concat operand must be an integer type");
90 unsigned width = opType.getWidth();
91 if (operandPossibleValues.empty()) {
92 uint64_t numStates = 1ULL << width;
93 // Add a threshold to prevent combinatorial explosion on large unknown
94 // inputs.
95 if (numStates > 256) { // Heuristic threshold
96 // If the search space is too large, we abandon the analysis for this
97 // path. The outer function will fall back to its own full-range
98 // default.
99 v.getDefiningOp()->emitWarning()
100 << "Search space too large (>" << 256
101 << " states) for operand with bitwidth " << width
102 << "; abandoning analysis for this path";
103 return;
104 }
105 for (uint64_t i = 0; i < numStates; ++i)
106 operandPossibleValues.insert(i);
107 }
108
109 allOperandValues.push_back(operandPossibleValues);
110 operandWidths.push_back(width);
111 }
112
113 // The shift for operand `i` is the sum of the widths of operands `i+1` to
114 // `n-1`.
115 llvm::SmallVector<unsigned> shifts(concatOp.getNumOperands(), 0);
116 for (int i = concatOp.getNumOperands() - 2; i >= 0; --i) {
117 shifts[i] = shifts[i + 1] + operandWidths[i + 1];
118 }
119
120 generateConcatenatedValues(allOperandValues, shifts, possibleValues);
121 return;
122 }
123
124 // --- Fallback Case ---
125 // If the operation is not recognized, assume all possible values for its
126 // bitwidth.
127
128 auto addrType = dyn_cast<IntegerType>(v.getType());
129 if (!addrType)
130 return; // Not an integer type we can analyze
131
132 unsigned bitWidth = addrType.getWidth();
133 // Again, use a threshold to avoid trying to enumerate 2^64 values.
134 if (bitWidth > 16) {
135 if (v.getDefiningOp())
136 v.getDefiningOp()->emitWarning()
137 << "Bitwidth " << bitWidth
138 << " too large (>16); abandoning analysis for this path";
139 return;
140 }
141
142 uint64_t numRegStates = 1ULL << bitWidth;
143 for (size_t i = 0; i < numRegStates; i++) {
144 possibleValues.insert(i);
145 }
146}
147
148static void addPossibleValues(llvm::SetVector<size_t> &possibleValues,
149 Value v) {
150 llvm::DenseSet<Value> visited;
151 addPossibleValuesImpl(possibleValues, v, visited);
152}
153
154/// Checks if two values compute the same function structurally.
155/// Values defined outside their respective regions (e.g., machine block
156/// arguments or fsm.variable results) are compared by SSA identity.
157/// Values defined by operations within their regions are compared using
158/// MLIR's OperationEquivalence with a custom equivalence callback.
159static bool areStructurallyEquivalent(Value a, Value b, Region &regionA,
160 Region &regionB) {
161 if (a == b)
162 return true;
163
164 Operation *opA = a.getDefiningOp();
165 Operation *opB = b.getDefiningOp();
166 if (!opA || !opB)
167 return false;
168
169 bool aIsLocal = regionA.isAncestor(opA->getParentRegion());
170 bool bIsLocal = regionB.isAncestor(opB->getParentRegion());
171 if (aIsLocal != bIsLocal)
172 return false;
173 if (!aIsLocal)
174 return false;
175
176 // Both local: compare result index and delegate structural comparison
177 // to MLIR's OperationEquivalence.
178 if (cast<OpResult>(a).getResultNumber() !=
179 cast<OpResult>(b).getResultNumber())
180 return false;
181
182 return OperationEquivalence::isEquivalentTo(
183 opA, opB,
184 [&](Value lhs, Value rhs) -> LogicalResult {
185 return success(areStructurallyEquivalent(lhs, rhs, regionA, regionB));
186 },
187 /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
188}
189
190/// RewritePattern that folds expressions in an action region that are
191/// structurally equivalent to guard conditions into boolean constants.
192
193class GuardConditionFoldPattern : public RewritePattern {
194public:
195 GuardConditionFoldPattern(MLIRContext *ctx,
196 ArrayRef<std::pair<Value, bool>> guardFacts,
197 Region &guardRegion, Region &actionRegion)
198 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/10, ctx),
199 guardFacts(guardFacts.begin(), guardFacts.end()),
200 guardRegion(guardRegion), actionRegion(actionRegion) {}
201
202 LogicalResult matchAndRewrite(Operation *op,
203 PatternRewriter &rewriter) const override {
204 if (!actionRegion.isAncestor(op->getParentRegion()))
205 return failure();
206
207 // Skip constants because replacing a constant with an identical constant
208 // creates an infinite loop in the greedy rewriter.
209 if (isa<hw::ConstantOp>(op))
210 return failure();
211
212 for (Value result : op->getResults()) {
213 if (!result.getType().isInteger(1) || result.use_empty())
214 continue;
215
216 for (auto [guardExpr, isTrue] : guardFacts) {
217 // Direct structural match.
218 if (areStructurallyEquivalent(guardExpr, result, guardRegion,
219 actionRegion))
220 return replaceWithConstant(result, isTrue, op, rewriter);
221
222 // ICmp predicate inversion: if the guard contains icmp P(X, Y),
223 // find icmp !P(X, Y) in the action and replace with !isTrue.
224 if (auto guardIcmp = guardExpr.getDefiningOp<ICmpOp>()) {
225 if (auto actionIcmp = dyn_cast<ICmpOp>(op)) {
226 if (actionIcmp.getPredicate() ==
227 ICmpOp::getNegatedPredicate(guardIcmp.getPredicate()) &&
228 areStructurallyEquivalent(guardIcmp.getLhs(),
229 actionIcmp.getLhs(), guardRegion,
230 actionRegion) &&
231 areStructurallyEquivalent(guardIcmp.getRhs(),
232 actionIcmp.getRhs(), guardRegion,
233 actionRegion))
234 return replaceWithConstant(result, !isTrue, op, rewriter);
235 }
236 }
237 }
238 }
239 return failure();
240 }
241
242private:
243 LogicalResult replaceWithConstant(Value result, bool constVal, Operation *op,
244 PatternRewriter &rewriter) const {
245 rewriter.setInsertionPointToStart(&actionRegion.front());
246 Value c = hw::ConstantOp::create(rewriter, op->getLoc(),
247 rewriter.getI1Type(), constVal ? 1 : 0);
248 rewriter.replaceAllUsesWith(result, c);
249 if (op->use_empty())
250 rewriter.eraseOp(op);
251 return success();
252 }
253
254 SmallVector<std::pair<Value, bool>> guardFacts;
255 Region &guardRegion;
256 Region &actionRegion;
257};
258
259/// Simplify action blocks by propagating guard conditions.
260/// When a transition is taken, its guard condition is known to be true.
261/// This function finds expressions in the action region that are structurally
262/// identical to the guard condition and replaces them with constant true.
263/// It also handles NOT (xor with true) and AND decomposition to propagate
264/// additional known values.
265static void simplifyActionWithGuard(TransitionOp transition,
266 OpBuilder &builder) {
267 Region &guardRegion = transition.getGuard();
268 Region &actionRegion = transition.getAction();
269
270 if (guardRegion.empty() || actionRegion.empty())
271 return;
272
273 auto guardReturn =
274 dyn_cast<fsm::ReturnOp>(guardRegion.front().getTerminator());
275 if (!guardReturn)
276 return;
277
278 Location loc = guardReturn.getLoc();
279 Value guardCondition = guardReturn.getOperand();
280
281 // Collect (expression, isTrue) pairs to propagate into the action.
282 SmallVector<std::pair<Value, bool>> guardFacts;
283
284 // Use a worklist to decompose AND expressions.
285 SmallVector<std::pair<Value, bool>> worklist;
286 worklist.push_back({guardCondition, true});
287
288 while (!worklist.empty()) {
289 auto [cond, isTrue] = worklist.pop_back_val();
290 guardFacts.push_back({cond, isTrue});
291
292 // Decompose NOT: xor(E, true) means E has the opposite truth value.
293 if (auto xorOp = cond.getDefiningOp<XorOp>()) {
294 if (xorOp.isBinaryNot() &&
295 guardRegion.isAncestor(xorOp->getParentRegion()))
296 guardFacts.push_back({xorOp.getOperand(0), !isTrue});
297 }
298
299 // Decompose AND: if the AND is true, each operand is true.
300 if (isTrue) {
301 if (auto andOp = cond.getDefiningOp<AndOp>()) {
302 if (guardRegion.isAncestor(andOp->getParentRegion())) {
303 for (Value operand : andOp.getOperands())
304 worklist.push_back({operand, true});
305 }
306 }
307 }
308 }
309
310 // Replace external guard values (block args or ops outside guard region)
311 // directly via replaceUsesWithIf.
312 for (auto [guardExpr, isTrue] : guardFacts) {
313 bool guardIsExternal =
314 !guardExpr.getDefiningOp() ||
315 !guardRegion.isAncestor(guardExpr.getDefiningOp()->getParentRegion());
316 if (guardIsExternal) {
317 builder.setInsertionPointToStart(&actionRegion.front());
318 auto constOp = hw::ConstantOp::create(builder, loc, guardExpr.getType(),
319 isTrue ? 1 : 0);
320 guardExpr.replaceUsesWithIf(constOp.getResult(), [&](OpOperand &use) {
321 if (!actionRegion.isAncestor(use.getOwner()->getParentRegion()))
322 return false;
323 // Never replace the variable operand of fsm.update because that is the
324 // assignment target, not a value to be folded.
325 if (auto updateOp = dyn_cast<UpdateOp>(use.getOwner()))
326 if (use.getOperandNumber() == 0)
327 return false;
328 return true;
329 });
330 }
331 }
332
333 // Use pattern rewriter for internal guard expression folding.
334 MLIRContext *ctx = transition.getContext();
335 RewritePatternSet patterns(ctx);
336 patterns.add<GuardConditionFoldPattern>(ctx, guardFacts, guardRegion,
337 actionRegion);
338 SmallVector<Operation *> actionOps;
339 actionRegion.walk([&](Operation *op) { actionOps.push_back(op); });
340 GreedyRewriteConfig config;
341 config.setScope(&actionRegion);
342 (void)applyOpPatternsGreedily(
343 actionOps, FrozenRewritePatternSet(std::move(patterns)), config);
344}
345
346/// Checks if a value is a constant or a tree of muxes with constant leaves.
347/// Uses an iterative approach with a visited set to handle cycles.
348static bool isConstantOrConstantTree(Value value) {
349 SmallVector<Value> worklist;
350 llvm::DenseSet<Value> visited;
351
352 worklist.push_back(value);
353 while (!worklist.empty()) {
354 Value current = worklist.pop_back_val();
355
356 // Skip if already visited (handles cycles).
357 if (!visited.insert(current).second)
358 continue;
359
360 Operation *definingOp = current.getDefiningOp();
361 if (!definingOp)
362 return false;
363
364 if (isa<hw::ConstantOp>(definingOp))
365 continue;
366
367 if (auto muxOp = dyn_cast<MuxOp>(definingOp)) {
368 worklist.push_back(muxOp.getTrueValue());
369 worklist.push_back(muxOp.getFalseValue());
370 continue;
371 }
372
373 // Not a constant or mux - not constant-like.
374 return false;
375 }
376 return true;
377}
378
379/// Pushes an ICmp equality comparison through a mux operation.
380/// This transforms `icmp eq (mux cond, x, y), b` into
381/// `mux cond, (icmp eq x, b), (icmp eq y, b)`.
382/// This simplification helps expose constant comparisons that can be folded
383/// during FSM extraction, making transition guards easier to analyze.
384LogicalResult pushIcmp(ICmpOp op, PatternRewriter &rewriter) {
385 APInt lhs, rhs;
386 if (op.getPredicate() == ICmpPredicate::eq &&
387 op.getLhs().getDefiningOp<MuxOp>() &&
388 (isConstantOrConstantTree(op.getLhs()) ||
389 op.getRhs().getDefiningOp<hw::ConstantOp>())) {
390 rewriter.setInsertionPointAfter(op);
391 auto mux = op.getLhs().getDefiningOp<MuxOp>();
392 Value x = mux.getTrueValue();
393 Value y = mux.getFalseValue();
394 Value b = op.getRhs();
395 Location loc = op.getLoc();
396 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
397 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
398 rewriter.replaceOpWithNewOp<MuxOp>(op, mux.getCond(), eq1.getResult(),
399 eq2.getResult());
400 return llvm::success();
401 }
402 if (op.getPredicate() == ICmpPredicate::eq &&
403 op.getRhs().getDefiningOp<MuxOp>() &&
404 (isConstantOrConstantTree(op.getRhs()) ||
405 op.getLhs().getDefiningOp<hw::ConstantOp>())) {
406 rewriter.setInsertionPointAfter(op);
407 auto mux = op.getRhs().getDefiningOp<MuxOp>();
408 Value x = mux.getTrueValue();
409 Value y = mux.getFalseValue();
410 Value b = op.getLhs();
411 Location loc = op.getLoc();
412 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
413 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
414 rewriter.replaceOpWithNewOp<MuxOp>(op, mux.getCond(), eq1.getResult(),
415 eq2.getResult());
416 return llvm::success();
417 }
418 return llvm::failure();
419}
420
421/// Iteratively builds all possible concatenated integer values from the
422/// Cartesian product of value sets.
423static void generateConcatenatedValues(
424 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
425 const llvm::SmallVector<unsigned> &shifts,
426 llvm::SetVector<size_t> &finalPossibleValues) {
427
428 if (allOperandValues.empty()) {
429 finalPossibleValues.insert(0);
430 return;
431 }
432
433 // Start with the values of the first operand, shifted appropriately.
434 llvm::SetVector<size_t> currentResults;
435 for (size_t val : allOperandValues[0])
436 currentResults.insert(val << shifts[0]);
437
438 // For each subsequent operand, combine with all existing partial results.
439 for (size_t operandIdx = 1; operandIdx < allOperandValues.size();
440 ++operandIdx) {
441 llvm::SetVector<size_t> nextResults;
442 unsigned shift = shifts[operandIdx];
443
444 for (size_t partialValue : currentResults) {
445 for (size_t val : allOperandValues[operandIdx]) {
446 nextResults.insert(partialValue | (val << shift));
447 }
448 }
449 currentResults = std::move(nextResults);
450 }
451
452 for (size_t val : currentResults)
453 finalPossibleValues.insert(val);
454}
455
456static llvm::MapVector<Value, int> intToRegMap(SmallVector<seq::CompRegOp> v,
457 int i) {
459 for (size_t ci = 0; ci < v.size(); ci++) {
460 seq::CompRegOp reg = v[ci];
461 int bits = reg.getType().getIntOrFloatBitWidth();
462 int v = i & ((1 << bits) - 1);
463 m[reg] = v;
464 i = i >> bits;
465 }
466 return m;
467}
468
469static int regMapToInt(SmallVector<seq::CompRegOp> v,
470 llvm::DenseMap<Value, int> m) {
471 int i = 0;
472 int width = 0;
473 for (size_t ci = 0; ci < v.size(); ci++) {
474 seq::CompRegOp reg = v[ci];
475 i += m[reg] * 1ULL << width;
476 width += (reg.getType().getIntOrFloatBitWidth());
477 }
478 return i;
479}
480
481/// Computes the Cartesian product of a list of sets.
482static std::set<llvm::SmallVector<size_t>> calculateCartesianProduct(
483 const llvm::SmallVector<llvm::SetVector<size_t>> &valueSets) {
484 std::set<llvm::SmallVector<size_t>> product;
485 if (valueSets.empty()) {
486 // The Cartesian product of zero sets is a set containing one element:
487 // the empty tuple (represented here by an empty vector).
488 product.insert({});
489 return product;
490 }
491
492 // Initialize the product with the elements of the first set, each in its
493 // own vector.
494 for (size_t value : valueSets.front()) {
495 product.insert({value});
496 }
497
498 // Iteratively build the product. For each subsequent set, create a new
499 // temporary product by appending each of its elements to every combination
500 // already generated.
501 for (size_t i = 1; i < valueSets.size(); ++i) {
502 const auto &currentSet = valueSets[i];
503 if (currentSet.empty()) {
504 // The Cartesian product with an empty set results in an empty set.
505 return {};
506 }
507
508 std::set<llvm::SmallVector<size_t>> newProduct;
509 for (const auto &existingVector : product) {
510 for (size_t newValue : currentSet) {
511 llvm::SmallVector<size_t> newVector = existingVector;
512 newVector.push_back(newValue);
513 newProduct.insert(std::move(newVector));
514 }
515 }
516 product = std::move(newProduct);
517 }
518
519 return product;
520}
521
522static FrozenRewritePatternSet loadPatterns(MLIRContext &context) {
523
524 RewritePatternSet patterns(&context);
525 for (auto *dialect : context.getLoadedDialects())
526 dialect->getCanonicalizationPatterns(patterns);
527 ICmpOp::getCanonicalizationPatterns(patterns, &context);
528 AndOp::getCanonicalizationPatterns(patterns, &context);
529 XorOp::getCanonicalizationPatterns(patterns, &context);
530 MuxOp::getCanonicalizationPatterns(patterns, &context);
531 ConcatOp::getCanonicalizationPatterns(patterns, &context);
532 ExtractOp::getCanonicalizationPatterns(patterns, &context);
533 AddOp::getCanonicalizationPatterns(patterns, &context);
534 OrOp::getCanonicalizationPatterns(patterns, &context);
535 MulOp::getCanonicalizationPatterns(patterns, &context);
536 hw::ConstantOp::getCanonicalizationPatterns(patterns, &context);
537 TransitionOp::getCanonicalizationPatterns(patterns, &context);
538 StateOp::getCanonicalizationPatterns(patterns, &context);
539 MachineOp::getCanonicalizationPatterns(patterns, &context);
540 patterns.add(pushIcmp);
541 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
542 return frozenPatterns;
543}
544
545static LogicalResult
546getReachableStates(llvm::SetVector<size_t> &visitableStates,
547 HWModuleOp moduleOp, size_t currentStateIndex,
548 SmallVector<seq::CompRegOp> registers) {
549
550 // Clone into an isolated mlir::ModuleOp to avoid violating symbol uniqueness
551 // verifier conditions (which would violate pattern invariants)
552 mlir::OwningOpRef<mlir::ModuleOp> analysisModule =
553 mlir::ModuleOp::create(moduleOp.getLoc());
554 OpBuilder b(moduleOp.getContext());
555 b.setInsertionPointToStart(analysisModule->getBody());
556
557 IRMapping mapping;
558 auto clonedBody = llvm::dyn_cast<HWModuleOp>(b.clone(*moduleOp, mapping));
559
561 intToRegMap(registers, currentStateIndex);
562 Operation *terminator = clonedBody.getBody().front().getTerminator();
563 auto output = dyn_cast<hw::OutputOp>(terminator);
564 SmallVector<Value> values;
565
566 for (auto [originalRegValue, constStateValue] : stateMap) {
567
568 Value clonedRegValue = mapping.lookup(originalRegValue);
569 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
570 auto reg = cast<seq::CompRegOp>(clonedRegOp);
571 Type constantType = reg.getType();
572 IntegerAttr constantAttr = b.getIntegerAttr(constantType, constStateValue);
573 b.setInsertionPoint(clonedRegOp);
574 auto otherStateConstant =
575 hw::ConstantOp::create(b, reg.getLoc(), constantAttr);
576 // If the register input is self-referential (input == output), use the
577 // constant we're replacing it with. Otherwise, the value would become
578 // dangling after we erase the register.
579 Value regInput = reg.getInput();
580 if (regInput == clonedRegValue)
581 values.push_back(otherStateConstant.getResult());
582 else
583 values.push_back(regInput);
584 clonedRegValue.replaceAllUsesWith(otherStateConstant.getResult());
585 reg.erase();
586 }
587 b.setInsertionPointToEnd(clonedBody.front().getBlock());
588 auto newOutput = hw::OutputOp::create(b, output.getLoc(), values);
589 output.erase();
590
591 // Update the module type to match the new hw.output operands to avoid
592 // violating verifier conditions
593 SmallVector<hw::ModulePort> newPorts;
594 for (hw::ModulePort p : clonedBody.getHWModuleType().getPorts())
595 if (p.dir == hw::ModulePort::Direction::Input)
596 newPorts.push_back(p);
597 for (auto [i, val] : llvm::enumerate(values))
598 newPorts.push_back({b.getStringAttr("out" + std::to_string(i)),
599 val.getType(), hw::ModulePort::Direction::Output});
600 clonedBody.setHWModuleType(hw::ModuleType::get(b.getContext(), newPorts));
601
602 FrozenRewritePatternSet frozenPatterns = loadPatterns(*moduleOp.getContext());
603
604 SmallVector<Operation *> opsToProcess;
605 clonedBody.walk([&](Operation *op) { opsToProcess.push_back(op); });
606
607 bool changed = false;
608 GreedyRewriteConfig config;
609 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns, config,
610 &changed)))
611 return failure();
612
613 llvm::SmallVector<llvm::SetVector<size_t>> pv;
614 for (size_t j = 0; j < newOutput.getNumOperands(); j++) {
615 llvm::SetVector<size_t> possibleValues;
616
617 Value v = newOutput.getOperand(j);
618 addPossibleValues(possibleValues, v);
619 pv.push_back(possibleValues);
620 }
621 std::set<llvm::SmallVector<size_t>> flipped = calculateCartesianProduct(pv);
622 for (llvm::SmallVector<size_t> v : flipped) {
623 llvm::DenseMap<Value, int> m;
624 for (size_t k = 0; k < v.size(); k++) {
625 seq::CompRegOp r = registers[k];
626 m[r] = v[k];
627 }
628
629 int i = regMapToInt(registers, m);
630 visitableStates.insert(i);
631 }
632
633 // Cloned body is destroyed when analysisModule goes out of scope
634 return success();
635}
636
637// A converter class to handle the logic of converting a single hw.module.
638class HWModuleOpConverter {
639public:
640 HWModuleOpConverter(OpBuilder &builder, HWModuleOp moduleOp,
641 ArrayRef<std::string> stateRegNames)
642 : moduleOp(moduleOp), opBuilder(builder), stateRegNames(stateRegNames) {}
643 LogicalResult run() {
644 SmallVector<seq::CompRegOp> stateRegs;
645 SmallVector<seq::CompRegOp> variableRegs;
646 Value foundClock, foundReset = nullptr;
647 WalkResult walkResult = moduleOp.walk([&](seq::CompRegOp reg) {
648 auto clk = reg.getClk();
649 auto reset = reg.getReset();
650 if (foundClock) {
651 if (clk != foundClock) {
652 reg.emitError("All registers must have the same clock signal.");
653 return WalkResult::interrupt();
654 }
655 } else {
656 foundClock = clk;
657 }
658
659 if (reset) {
660 if (foundReset) {
661 if (reset != foundReset) {
662 reg.emitError("All registers must have the same reset signal.");
663 return WalkResult::interrupt();
664 }
665 } else {
666 foundReset = reset;
667 }
668 }
669
670 // Check that the register type is an integer.
671 if (!isa<IntegerType>(reg.getType())) {
672 reg.emitError("FSM extraction only supports integer-typed registers");
673 return WalkResult::interrupt();
674 }
675 if (isStateRegister(reg)) {
676 stateRegs.push_back(reg);
677 } else {
678 variableRegs.push_back(reg);
679 }
680 return WalkResult::advance();
681 });
682 if (walkResult.wasInterrupted())
683 return failure();
684 if (stateRegs.empty()) {
685 emitError(moduleOp.getLoc())
686 << "Cannot find state register in this FSM. Use the state-regs "
687 "option to specify which registers are state registers.";
688 return failure();
689 }
690 SmallVector<seq::CompRegOp> registers;
691 for (seq::CompRegOp c : stateRegs) {
692 registers.push_back(c);
693 }
694
695 llvm::DenseMap<size_t, StateOp> stateToStateOp;
696 llvm::DenseMap<StateOp, size_t> stateOpToState;
697 // Collect reset arguments to exclude from the FSM's function type.
698 // All CompReg reset signals are ignored during FSM extraction since the
699 // FSM dialect does not have an explicit reset concept. The reset behavior
700 // is only captured in the initial state value.
701 llvm::DenseSet<size_t> asyncResetArguments;
702 Location loc = moduleOp.getLoc();
703 SmallVector<Type> inputTypes = moduleOp.getInputTypes();
704
705 // Create a new FSM machine with the current state.
706 auto resultTypes = moduleOp.getOutputTypes();
707 FunctionType machineType =
708 FunctionType::get(opBuilder.getContext(), inputTypes, resultTypes);
709 StringRef machineName = moduleOp.getName();
710
711 llvm::DenseMap<Value, int> initialStateMap;
712 for (seq::CompRegOp reg : moduleOp.getOps<seq::CompRegOp>()) {
713 Value resetValue = reg.getResetValue();
714 hw::ConstantOp definingConstant;
715 if (resetValue) {
716 definingConstant = resetValue.getDefiningOp<hw::ConstantOp>();
717 } else {
718 // Assume that registers without a reset start at 0
719 reg.emitWarning("Assuming register with no reset starts with value 0");
720 definingConstant =
721 hw::ConstantOp::create(opBuilder, reg.getLoc(), reg.getType(), 0);
722 }
723 if (!definingConstant) {
724 reg->emitError(
725 "cannot find defining constant for reset value of register");
726 return failure();
727 }
728 int resetValueInt =
729 definingConstant.getValueAttr().getValue().getZExtValue();
730 initialStateMap[reg] = resetValueInt;
731 }
732 int initialStateIndex = regMapToInt(registers, initialStateMap);
733
734 std::string initialStateName = "state_" + std::to_string(initialStateIndex);
735
736 // Preserve argument and result names, which are stored as attributes.
737 SmallVector<NamedAttribute> machineAttrs;
738 if (auto argNames = moduleOp->getAttrOfType<ArrayAttr>("argNames"))
739 machineAttrs.emplace_back(opBuilder.getStringAttr("argNames"), argNames);
740 if (auto resNames = moduleOp->getAttrOfType<ArrayAttr>("resultNames"))
741 machineAttrs.emplace_back(opBuilder.getStringAttr("resNames"), resNames);
742
743 // The builder for fsm.MachineOp will create the body region and block
744 // arguments.
745 opBuilder.setInsertionPoint(moduleOp);
746 auto machine =
747 MachineOp::create(opBuilder, loc, machineName, initialStateName,
748 machineType, machineAttrs);
749
750 OpBuilder::InsertionGuard guard(opBuilder);
751 opBuilder.setInsertionPointToStart(&machine.getBody().front());
753 for (seq::CompRegOp varReg : variableRegs) {
754 TypedValue<Type> initialValue = varReg.getResetValue();
755 hw::ConstantOp definingConstant;
756 if (initialValue) {
757 definingConstant = initialValue.getDefiningOp<hw::ConstantOp>();
758 } else {
759 // Assume that registers without a reset start at 0
760 varReg.emitWarning(
761 "Assuming register with no reset starts with value 0");
762 definingConstant = hw::ConstantOp::create(opBuilder, varReg.getLoc(),
763 varReg.getType(), 0);
764 }
765 if (!definingConstant) {
766 varReg->emitError("cannot find defining constant for reset value of "
767 "variable register");
768 return failure();
769 }
770 auto variableOp = VariableOp::create(
771 opBuilder, varReg->getLoc(), varReg.getInput().getType(),
772 definingConstant.getValueAttr(), varReg.getName().value_or("var"));
773 variableMap[varReg] = variableOp;
774 }
775
776 // Load rewrite patterns used for canonicalizing the generated FSM.
777 FrozenRewritePatternSet frozenPatterns =
778 loadPatterns(*moduleOp.getContext());
779
780 SetVector<int> reachableStates;
781 SmallVector<int> worklist;
782
783 worklist.push_back(initialStateIndex);
784 reachableStates.insert(initialStateIndex);
785 // Process states in BFS order. The worklist grows as new reachable states
786 // are discovered, so we use an index-based loop.
787 for (unsigned i = 0; i < worklist.size(); ++i) {
788
789 int currentStateIndex = worklist[i];
790
792 intToRegMap(registers, currentStateIndex);
793
794 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
795
796 StateOp stateOp;
797
798 if (!stateToStateOp.contains(currentStateIndex)) {
799 stateOp = StateOp::create(opBuilder, loc,
800 "state_" + std::to_string(currentStateIndex));
801 stateToStateOp.insert({currentStateIndex, stateOp});
802 stateOpToState.insert({stateOp, currentStateIndex});
803 } else {
804 stateOp = stateToStateOp.lookup(currentStateIndex);
805 }
806 Region &outputRegion = stateOp.getOutput();
807 Block *outputBlock = &outputRegion.front();
808 opBuilder.setInsertionPointToStart(outputBlock);
809 IRMapping mapping;
810 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), outputRegion,
811 outputBlock->getIterator(), mapping);
812 outputBlock->erase();
813
814 auto *terminator = outputRegion.front().getTerminator();
815 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
816 assert(hwOutputOp && "Expected terminator to be hw.output op");
817
818 // Position the builder to insert the new terminator right before the
819 // old one.
820 OpBuilder::InsertionGuard stateGuard(opBuilder);
821 opBuilder.setInsertionPoint(hwOutputOp);
822
823 // Create the new fsm.OutputOp with the same operands.
824
825 fsm::OutputOp::create(opBuilder, hwOutputOp.getLoc(),
826 hwOutputOp.getOperands());
827
828 // Erase the old terminator.
829 hwOutputOp.erase();
830
831 // Iterate through the state configuration to replace registers
832 // with constants.
833 for (auto &[originalRegValue, variableOp] : variableMap) {
834 Value clonedRegValue = mapping.lookup(originalRegValue);
835 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
836 auto reg = cast<seq::CompRegOp>(clonedRegOp);
837 const auto res = variableOp.getResult();
838 clonedRegValue.replaceAllUsesWith(res);
839 reg.erase();
840 }
841 for (auto const &[originalRegValue, constStateValue] : stateMap) {
842 // Find the cloned register's result value using the mapping.
843 Value clonedRegValue = mapping.lookup(originalRegValue);
844 assert(clonedRegValue &&
845 "Original register value not found in mapping");
846 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
847
848 assert(clonedRegOp && "Cloned value must have a defining op");
849 opBuilder.setInsertionPoint(clonedRegOp);
850 auto r = cast<seq::CompRegOp>(clonedRegOp);
851 TypedValue<IntegerType> registerReset = r.getReset();
852 if (registerReset) {
853 if (BlockArgument blockArg = dyn_cast<BlockArgument>(registerReset)) {
854 asyncResetArguments.insert(blockArg.getArgNumber());
855 auto falseConst = hw::ConstantOp::create(
856 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
857 blockArg.replaceAllUsesWith(falseConst.getResult());
858 }
859 if (auto xorOp = registerReset.getDefiningOp<XorOp>()) {
860 if (xorOp.isBinaryNot()) {
861 Value rhs = xorOp.getOperand(0);
862 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
863 asyncResetArguments.insert(blockArg.getArgNumber());
864 auto trueConst = hw::ConstantOp::create(
865 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
866 blockArg.replaceAllUsesWith(trueConst.getResult());
867 }
868 }
869 }
870 }
871 auto constantOp =
872 hw::ConstantOp::create(opBuilder, clonedRegValue.getLoc(),
873 clonedRegValue.getType(), constStateValue);
874 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
875 clonedRegOp->erase();
876 }
877 GreedyRewriteConfig config;
878 SmallVector<Operation *> opsToProcess;
879 outputRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
880 // Replace references to arguments in the output block with
881 // arguments at the top level.
882 for (auto arg : outputRegion.front().getArguments()) {
883 int argIndex = arg.getArgNumber();
884 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
885 arg.replaceAllUsesWith(topLevelArg);
886 }
887 outputRegion.front().eraseArguments(
888 [](BlockArgument arg) { return true; });
889 // hw.module uses graph regions that allow cycles (e.g., registers feeding
890 // back into themselves). By this point we've replaced all registers with
891 // constants, but cycles in purely combinational logic (e.g., cyclic
892 // muxes) may still exist. Such cycles cannot be converted to FSM.
893 bool sorted = sortTopologically(&outputRegion.front());
894 if (!sorted) {
895 moduleOp.emitError()
896 << "cannot convert module with combinational cycles to FSM";
897 return failure();
898 }
899 FrozenRewritePatternSet patterns(opBuilder.getContext());
900 config.setScope(&outputRegion);
901
902 bool changed = false;
903 if (failed(applyOpPatternsGreedily(opsToProcess, patterns, config,
904 &changed)))
905 return failure();
906 opBuilder.setInsertionPoint(stateOp);
907 Region &transitionRegion = stateOp.getTransitions();
908 llvm::SetVector<size_t> visitableStates;
909 if (failed(getReachableStates(visitableStates, moduleOp,
910 currentStateIndex, registers)))
911 return failure();
912 for (size_t j : visitableStates) {
913 StateOp toState;
914 if (!stateToStateOp.contains(j)) {
915 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
916 toState =
917 StateOp::create(opBuilder, loc, "state_" + std::to_string(j));
918 stateToStateOp.insert({j, toState});
919 stateOpToState.insert({toState, j});
920 } else {
921 toState = stateToStateOp[j];
922 }
923 opBuilder.setInsertionPointToStart(&transitionRegion.front());
924 auto transitionOp =
925 TransitionOp::create(opBuilder, loc, "state_" + std::to_string(j));
926 Region &guardRegion = transitionOp.getGuard();
927 opBuilder.createBlock(&guardRegion);
928
929 Block &guardBlock = guardRegion.front();
930
931 opBuilder.setInsertionPointToStart(&guardBlock);
932 IRMapping mapping;
933 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), guardRegion,
934 guardBlock.getIterator(), mapping);
935 guardBlock.erase();
936 Block &newGuardBlock = guardRegion.front();
937 Operation *terminator = newGuardBlock.getTerminator();
938 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
939 assert(hwOutputOp && "Expected terminator to be hw.output op");
940
941 llvm::MapVector<Value, int> toStateMap = intToRegMap(registers, j);
942 SmallVector<Value> equalityChecks;
943 for (auto &[originalRegValue, variableOp] : variableMap) {
944 opBuilder.setInsertionPointToStart(&newGuardBlock);
945 Value clonedRegValue = mapping.lookup(originalRegValue);
946 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
947 auto reg = cast<seq::CompRegOp>(clonedRegOp);
948 const auto res = variableOp.getResult();
949 clonedRegValue.replaceAllUsesWith(res);
950 reg.erase();
951 }
952 for (auto const &[originalRegValue, constStateValue] : toStateMap) {
953
954 Value clonedRegValue = mapping.lookup(originalRegValue);
955 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
956 opBuilder.setInsertionPoint(clonedRegOp);
957 auto r = cast<seq::CompRegOp>(clonedRegOp);
958
959 Value registerInput = r.getInput();
960 TypedValue<IntegerType> registerReset = r.getReset();
961 if (registerReset) {
962 if (BlockArgument blockArg =
963 dyn_cast<BlockArgument>(registerReset)) {
964 auto falseConst = hw::ConstantOp::create(
965 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
966 blockArg.replaceAllUsesWith(falseConst.getResult());
967 }
968 if (auto xorOp = registerReset.getDefiningOp<XorOp>()) {
969 if (xorOp.isBinaryNot()) {
970 Value rhs = xorOp.getOperand(0);
971 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
972 auto trueConst = hw::ConstantOp::create(
973 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
974 blockArg.replaceAllUsesWith(trueConst.getResult());
975 }
976 }
977 }
978 }
979 Type constantType = registerInput.getType();
980 IntegerAttr constantAttr =
981 opBuilder.getIntegerAttr(constantType, constStateValue);
982 auto otherStateConstant = hw::ConstantOp::create(
983 opBuilder, hwOutputOp.getLoc(), constantAttr);
984
985 auto doesEqual =
986 ICmpOp::create(opBuilder, hwOutputOp.getLoc(), ICmpPredicate::eq,
987 registerInput, otherStateConstant.getResult());
988 equalityChecks.push_back(doesEqual.getResult());
989 }
990 opBuilder.setInsertionPoint(hwOutputOp);
991 auto allEqualCheck = AndOp::create(opBuilder, hwOutputOp.getLoc(),
992 equalityChecks, false);
993 fsm::ReturnOp::create(opBuilder, hwOutputOp.getLoc(),
994 allEqualCheck.getResult());
995 hwOutputOp.erase();
996 for (BlockArgument arg : newGuardBlock.getArguments()) {
997 int argIndex = arg.getArgNumber();
998 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
999 arg.replaceAllUsesWith(topLevelArg);
1000 }
1001 newGuardBlock.eraseArguments([](BlockArgument arg) { return true; });
1002 llvm::MapVector<Value, int> fromStateMap =
1003 intToRegMap(registers, currentStateIndex);
1004 for (auto const &[originalRegValue, constStateValue] : fromStateMap) {
1005 Value clonedRegValue = mapping.lookup(originalRegValue);
1006 assert(clonedRegValue &&
1007 "Original register value not found in mapping");
1008 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1009 assert(clonedRegOp && "Cloned value must have a defining op");
1010 opBuilder.setInsertionPoint(clonedRegOp);
1011 auto constantOp =
1012 hw::ConstantOp::create(opBuilder, clonedRegValue.getLoc(),
1013 clonedRegValue.getType(), constStateValue);
1014 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
1015 clonedRegOp->erase();
1016 }
1017 // Sort before running patterns to avoid violating dominance and
1018 // therefore pattern invariants
1019 bool guardSorted = sortTopologically(&newGuardBlock);
1020 if (!guardSorted) {
1021 moduleOp.emitError()
1022 << "cannot convert module with combinational cycles to FSM";
1023 return failure();
1024 }
1025 Region &actionRegion = transitionOp.getAction();
1026 if (!variableRegs.empty()) {
1027 Block *actionBlock = opBuilder.createBlock(&actionRegion);
1028 opBuilder.setInsertionPointToStart(actionBlock);
1029 IRMapping mapping;
1030 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), actionRegion,
1031 actionBlock->getIterator(), mapping);
1032 actionBlock->erase();
1033 Block &newActionBlock = actionRegion.front();
1034 for (BlockArgument arg : newActionBlock.getArguments()) {
1035 int argIndex = arg.getArgNumber();
1036 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
1037 arg.replaceAllUsesWith(topLevelArg);
1038 }
1039 newActionBlock.eraseArguments([](BlockArgument arg) { return true; });
1040 for (auto &[originalRegValue, variableOp] : variableMap) {
1041 Value clonedRegValue = mapping.lookup(originalRegValue);
1042 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1043 auto reg = cast<seq::CompRegOp>(clonedRegOp);
1044 opBuilder.setInsertionPointToStart(&newActionBlock);
1045 UpdateOp::create(opBuilder, reg.getLoc(), variableOp,
1046 reg.getInput());
1047 const Value res = variableOp.getResult();
1048 clonedRegValue.replaceAllUsesWith(res);
1049 reg.erase();
1050 }
1051 Operation *terminator = actionRegion.back().getTerminator();
1052 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
1053 assert(hwOutputOp && "Expected terminator to be hw.output op");
1054 hwOutputOp.erase();
1055
1056 for (auto const &[originalRegValue, constStateValue] : fromStateMap) {
1057 Value clonedRegValue = mapping.lookup(originalRegValue);
1058 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1059 opBuilder.setInsertionPoint(clonedRegOp);
1060 auto constantOp = hw::ConstantOp::create(
1061 opBuilder, clonedRegValue.getLoc(), clonedRegValue.getType(),
1062 constStateValue);
1063 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
1064 clonedRegOp->erase();
1065 }
1066
1067 // Sort before running patterns to avoid violating dominance and
1068 // therefore pattern invariants
1069 bool actionSorted = sortTopologically(&actionRegion.front());
1070 if (!actionSorted) {
1071 moduleOp.emitError()
1072 << "cannot convert module with combinational cycles to FSM";
1073 return failure();
1074 }
1075
1076 GreedyRewriteConfig config;
1077 SmallVector<Operation *> opsToProcess;
1078 actionRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
1079 config.setScope(&actionRegion);
1080
1081 bool changed = false;
1082 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns,
1083 config, &changed)))
1084 return failure();
1085 }
1086
1087 SmallVector<Operation *> outputOps;
1088 stateOp.getOutput().walk(
1089 [&](Operation *op) { outputOps.push_back(op); });
1090
1091 bool changed = false;
1092 GreedyRewriteConfig config;
1093 config.setScope(&stateOp.getOutput());
1094 LogicalResult converged = applyOpPatternsGreedily(
1095 outputOps, frozenPatterns, config, &changed);
1096 assert(succeeded(converged) && "canonicalization failed to converge");
1097 SmallVector<Operation *> transitionOps;
1098 stateOp.getTransitions().walk(
1099 [&](Operation *op) { transitionOps.push_back(op); });
1100
1101 GreedyRewriteConfig config2;
1102 config2.setScope(&stateOp.getTransitions());
1103 if (failed(applyOpPatternsGreedily(transitionOps, frozenPatterns,
1104 config2, &changed))) {
1105 return failure();
1106 }
1107
1108 // Propagate guard conditions into action blocks to eliminate
1109 // redundant muxes. The guard already checks the transition
1110 // condition, so the action block can assume it holds.
1111 for (TransitionOp transition :
1112 stateOp.getTransitions().getOps<TransitionOp>())
1113 simplifyActionWithGuard(transition, opBuilder);
1114
1115 // Re-canonicalize after guard propagation to clean up dead ops.
1116 {
1117 SmallVector<Operation *> postOps;
1118 stateOp.getTransitions().walk(
1119 [&](Operation *op) { postOps.push_back(op); });
1120 GreedyRewriteConfig postConfig;
1121 postConfig.setScope(&stateOp.getTransitions());
1122 if (failed(applyOpPatternsGreedily(postOps, frozenPatterns,
1123 postConfig, &changed)))
1124 return failure();
1125 }
1126
1127 for (TransitionOp transition :
1128 stateOp.getTransitions().getOps<TransitionOp>()) {
1129 StateOp nextState = transition.getNextStateOp();
1130 int nextStateIndex = stateOpToState.lookup(nextState);
1131 auto guardConst = transition.getGuardReturn()
1132 .getOperand()
1133 .getDefiningOp<hw::ConstantOp>();
1134 bool nextStateIsReachable =
1135 !guardConst || (guardConst.getValueAttr().getInt() != 0);
1136 // If we find a valid next state and haven't seen it before, add it to
1137 // the worklist and the set of reachable states.
1138 if (nextStateIsReachable &&
1139 !reachableStates.contains(nextStateIndex)) {
1140 worklist.push_back(nextStateIndex);
1141 reachableStates.insert(nextStateIndex);
1142 }
1143 }
1144 }
1145 }
1146
1147 // Clean up unreachable states. States without an output region are
1148 // placeholder states that were created during reachability analysis but
1149 // never populated (i.e., they are unreachable from the initial state).
1150 SmallVector<StateOp> statesToErase;
1151
1152 // Collect unreachable states (those without an output op).
1153 for (StateOp stateOp : machine.getOps<StateOp>()) {
1154 if (!stateOp.getOutputOp()) {
1155 statesToErase.push_back(stateOp);
1156 }
1157 }
1158
1159 // Erase states in a separate loop to avoid iterator invalidation. We first
1160 // collect all states to erase, then iterate over that list. This is
1161 // necessary because erasing a state while iterating over machine.getOps()
1162 // would invalidate the iterator.
1163 for (StateOp stateOp : statesToErase) {
1164 for (TransitionOp transition : machine.getOps<TransitionOp>()) {
1165 if (transition.getNextStateOp().getSymName() == stateOp.getSymName()) {
1166 transition.erase();
1167 }
1168 }
1169 stateOp.erase();
1170 }
1171
1172 llvm::DenseSet<BlockArgument> asyncResetBlockArguments;
1173 for (auto arg : machine.getBody().front().getArguments()) {
1174 if (asyncResetArguments.contains(arg.getArgNumber())) {
1175 asyncResetBlockArguments.insert(arg);
1176 }
1177 }
1178
1179 // Emit a warning if reset signals were detected and removed.
1180 // The FSM dialect does not support reset signals, so the reset behavior
1181 // is only captured in the initial state. The original reset triggering
1182 // mechanism is not preserved.
1183 if (!asyncResetBlockArguments.empty()) {
1184 moduleOp.emitWarning()
1185 << "reset signals detected and removed from FSM; "
1186 "reset behavior is captured only in the initial state";
1187 }
1188
1189 Block &front = machine.getBody().front();
1190 front.eraseArguments([&](BlockArgument arg) {
1191 return asyncResetBlockArguments.contains(arg);
1192 });
1193
1194 if (llvm::any_of(front.getArguments(), [](BlockArgument arg) {
1195 return arg.getType() == seq::ClockType::get(arg.getContext()) &&
1196 arg.hasNUsesOrMore(1);
1197 })) {
1198 moduleOp.emitError("Clock uses outside register clocking are not "
1199 "currently supported.");
1200 return failure();
1201 }
1202 machine.getBody().front().eraseArguments([&](BlockArgument arg) {
1203 return arg.getType() == seq::ClockType::get(arg.getContext());
1204 });
1205 FunctionType oldFunctionType = machine.getFunctionType();
1206 SmallVector<Type> inputsWithoutClock;
1207 for (unsigned int i = 0; i < oldFunctionType.getNumInputs(); i++) {
1208 Type input = oldFunctionType.getInput(i);
1209 if (input != seq::ClockType::get(input.getContext()) &&
1210 !asyncResetArguments.contains(i))
1211 inputsWithoutClock.push_back(input);
1212 }
1213
1214 FunctionType newFunctionType = FunctionType::get(
1215 opBuilder.getContext(), inputsWithoutClock, resultTypes);
1216
1217 machine.setFunctionType(newFunctionType);
1218 moduleOp.erase();
1219 return success();
1220 }
1221
1222private:
1223 /// Helper function to determine if a register is a state register.
1224 bool isStateRegister(seq::CompRegOp reg) const {
1225 auto regName = reg.getName();
1226 if (!regName)
1227 return false;
1228
1229 // If user specified state registers, check if this register's name matches
1230 // any of them.
1231 if (!stateRegNames.empty()) {
1232 return llvm::is_contained(stateRegNames, regName->str());
1233 }
1234
1235 // Default behavior: infer state registers by checking if the name contains
1236 // "state".
1237 return regName->contains("state");
1238 }
1239
1240 HWModuleOp moduleOp;
1241 OpBuilder &opBuilder;
1242 ArrayRef<std::string> stateRegNames;
1243};
1244
1245} // namespace
1246
1247namespace {
1248struct CoreToFSMPass : public circt::impl::ConvertCoreToFSMBase<CoreToFSMPass> {
1249 using ConvertCoreToFSMBase<CoreToFSMPass>::ConvertCoreToFSMBase;
1250
1251 void runOnOperation() override {
1252 auto module = getOperation();
1253 OpBuilder builder(module);
1254
1255 SmallVector<HWModuleOp> modules;
1256 for (auto hwModule : module.getOps<HWModuleOp>()) {
1257 modules.push_back(hwModule);
1258 }
1259
1260 // Check for hw.instance operations - instance conversion is not supported.
1261 for (auto hwModule : modules) {
1262 for (auto instance : hwModule.getOps<hw::InstanceOp>()) {
1263 instance.emitError() << "instance conversion is not yet supported";
1264 signalPassFailure();
1265 return;
1266 }
1267 }
1268
1269 for (auto hwModule : modules) {
1270 builder.setInsertionPoint(hwModule);
1271 HWModuleOpConverter converter(builder, hwModule, stateRegs);
1272 if (failed(converter.run())) {
1273 signalPassFailure();
1274 return;
1275 }
1276 }
1277 }
1278};
1279} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
@ Input
Definition HW.h:42
create(*operands)
Definition fsm.py:160
create(name)
Definition fsm.py:144
create(to_state)
Definition fsm.py:110
create(data_type, value)
Definition hw.py:433
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:1716
Definition fsm.py:1
Definition hw.py:1
Definition seq.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21