CIRCT 23.0.0git
Loading...
Searching...
No Matches
CoreToFSM.cpp
Go to the documentation of this file.
1//===CoreToFSM.cpp - Convert Core Dialects (HW + Seq + Comb) to FSM Dialect===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17#include "circt/Support/LLVM.h"
18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/IR/Block.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/Diagnostics.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "mlir/Transforms/Passes.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/LogicalResult.h"
36#include <string>
37
38namespace circt {
39#define GEN_PASS_DEF_CONVERTCORETOFSM
40#include "circt/Conversion/Passes.h.inc"
41} // namespace circt
42
43using namespace mlir;
44using namespace circt;
45using namespace hw;
46using namespace comb;
47using namespace fsm;
48
49namespace {
50
51// Forward declaration for our helper function
52static void generateConcatenatedValues(
53 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
54 const llvm::SmallVector<unsigned> &shifts,
55 llvm::SetVector<size_t> &finalPossibleValues);
56
57/// Internal helper with visited set to detect cycles.
58static void addPossibleValuesImpl(llvm::SetVector<size_t> &possibleValues,
59 Value v, llvm::DenseSet<Value> &visited) {
60 // Detect cycles - if we've seen this value before, skip it.
61 if (!visited.insert(v).second)
62 return;
63
64 if (auto c = dyn_cast_or_null<hw::ConstantOp>(v.getDefiningOp())) {
65 possibleValues.insert(c.getValueAttr().getValue().getZExtValue());
66 return;
67 }
68 if (auto m = dyn_cast_or_null<MuxOp>(v.getDefiningOp())) {
69 addPossibleValuesImpl(possibleValues, m.getTrueValue(), visited);
70 addPossibleValuesImpl(possibleValues, m.getFalseValue(), visited);
71 return;
72 }
73
74 if (auto concatOp = dyn_cast_or_null<ConcatOp>(v.getDefiningOp())) {
75 llvm::SmallVector<llvm::SetVector<size_t>> allOperandValues;
76 llvm::SmallVector<unsigned> operandWidths;
77
78 for (Value operand : concatOp.getOperands()) {
79 llvm::SetVector<size_t> operandPossibleValues;
80 addPossibleValuesImpl(operandPossibleValues, operand, visited);
81
82 // It's crucial to handle the case where a sub-computation is too complex.
83 // If we can't determine specific values for an operand, we must
84 // pessimistically assume it can be any value its bitwidth allows.
85 auto opType = dyn_cast<IntegerType>(operand.getType());
86 // comb.concat only accepts signless integer operands by definition in
87 // CIRCT's type system, so this assertion should always hold for valid IR.
88 assert(opType && "comb.concat operand must be an integer type");
89 unsigned width = opType.getWidth();
90 if (operandPossibleValues.empty()) {
91 uint64_t numStates = 1ULL << width;
92 // Add a threshold to prevent combinatorial explosion on large unknown
93 // inputs.
94 if (numStates > 256) { // Heuristic threshold
95 // If the search space is too large, we abandon the analysis for this
96 // path. The outer function will fall back to its own full-range
97 // default.
98 v.getDefiningOp()->emitWarning()
99 << "Search space too large (>" << 256
100 << " states) for operand with bitwidth " << width
101 << "; abandoning analysis for this path";
102 return;
103 }
104 for (uint64_t i = 0; i < numStates; ++i)
105 operandPossibleValues.insert(i);
106 }
107
108 allOperandValues.push_back(operandPossibleValues);
109 operandWidths.push_back(width);
110 }
111
112 // The shift for operand `i` is the sum of the widths of operands `i+1` to
113 // `n-1`.
114 llvm::SmallVector<unsigned> shifts(concatOp.getNumOperands(), 0);
115 for (int i = concatOp.getNumOperands() - 2; i >= 0; --i) {
116 shifts[i] = shifts[i + 1] + operandWidths[i + 1];
117 }
118
119 generateConcatenatedValues(allOperandValues, shifts, possibleValues);
120 return;
121 }
122
123 // --- Fallback Case ---
124 // If the operation is not recognized, assume all possible values for its
125 // bitwidth.
126
127 auto addrType = dyn_cast<IntegerType>(v.getType());
128 if (!addrType)
129 return; // Not an integer type we can analyze
130
131 unsigned bitWidth = addrType.getWidth();
132 // Again, use a threshold to avoid trying to enumerate 2^64 values.
133 if (bitWidth > 16) {
134 if (v.getDefiningOp())
135 v.getDefiningOp()->emitWarning()
136 << "Bitwidth " << bitWidth
137 << " too large (>16); abandoning analysis for this path";
138 return;
139 }
140
141 uint64_t numRegStates = 1ULL << bitWidth;
142 for (size_t i = 0; i < numRegStates; i++) {
143 possibleValues.insert(i);
144 }
145}
146
147static void addPossibleValues(llvm::SetVector<size_t> &possibleValues,
148 Value v) {
149 llvm::DenseSet<Value> visited;
150 addPossibleValuesImpl(possibleValues, v, visited);
151}
152
153/// Checks if two values compute the same function structurally.
154/// Values defined outside their respective regions (e.g., machine block
155/// arguments or fsm.variable results) are compared by SSA identity.
156/// Values defined by operations within their regions are compared using
157/// MLIR's OperationEquivalence with a custom equivalence callback.
158static bool areStructurallyEquivalent(Value a, Value b, Region &regionA,
159 Region &regionB) {
160 if (a == b)
161 return true;
162
163 Operation *opA = a.getDefiningOp();
164 Operation *opB = b.getDefiningOp();
165 if (!opA || !opB)
166 return false;
167
168 bool aIsLocal = regionA.isAncestor(opA->getParentRegion());
169 bool bIsLocal = regionB.isAncestor(opB->getParentRegion());
170 if (aIsLocal != bIsLocal)
171 return false;
172 if (!aIsLocal)
173 return false;
174
175 // Both local: compare result index and delegate structural comparison
176 // to MLIR's OperationEquivalence.
177 if (cast<OpResult>(a).getResultNumber() !=
178 cast<OpResult>(b).getResultNumber())
179 return false;
180
181 return OperationEquivalence::isEquivalentTo(
182 opA, opB,
183 [&](Value lhs, Value rhs) -> LogicalResult {
184 return success(areStructurallyEquivalent(lhs, rhs, regionA, regionB));
185 },
186 /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
187}
188
189/// RewritePattern that folds expressions in an action region that are
190/// structurally equivalent to guard conditions into boolean constants.
191
192class GuardConditionFoldPattern : public RewritePattern {
193public:
194 GuardConditionFoldPattern(MLIRContext *ctx,
195 ArrayRef<std::pair<Value, bool>> guardFacts,
196 Region &guardRegion, Region &actionRegion)
197 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/10, ctx),
198 guardFacts(guardFacts.begin(), guardFacts.end()),
199 guardRegion(guardRegion), actionRegion(actionRegion) {}
200
201 LogicalResult matchAndRewrite(Operation *op,
202 PatternRewriter &rewriter) const override {
203 if (!actionRegion.isAncestor(op->getParentRegion()))
204 return failure();
205
206 // Skip constants because replacing a constant with an identical constant
207 // creates an infinite loop in the greedy rewriter.
208 if (isa<hw::ConstantOp>(op))
209 return failure();
210
211 for (Value result : op->getResults()) {
212 if (!result.getType().isInteger(1) || result.use_empty())
213 continue;
214
215 for (auto [guardExpr, isTrue] : guardFacts) {
216 // Direct structural match.
217 if (areStructurallyEquivalent(guardExpr, result, guardRegion,
218 actionRegion))
219 return replaceWithConstant(result, isTrue, op, rewriter);
220
221 // ICmp predicate inversion: if the guard contains icmp P(X, Y),
222 // find icmp !P(X, Y) in the action and replace with !isTrue.
223 if (auto guardIcmp = guardExpr.getDefiningOp<ICmpOp>()) {
224 if (auto actionIcmp = dyn_cast<ICmpOp>(op)) {
225 if (actionIcmp.getPredicate() ==
226 ICmpOp::getNegatedPredicate(guardIcmp.getPredicate()) &&
227 areStructurallyEquivalent(guardIcmp.getLhs(),
228 actionIcmp.getLhs(), guardRegion,
229 actionRegion) &&
230 areStructurallyEquivalent(guardIcmp.getRhs(),
231 actionIcmp.getRhs(), guardRegion,
232 actionRegion))
233 return replaceWithConstant(result, !isTrue, op, rewriter);
234 }
235 }
236 }
237 }
238 return failure();
239 }
240
241private:
242 LogicalResult replaceWithConstant(Value result, bool constVal, Operation *op,
243 PatternRewriter &rewriter) const {
244 rewriter.setInsertionPointToStart(&actionRegion.front());
245 Value c = hw::ConstantOp::create(rewriter, op->getLoc(),
246 rewriter.getI1Type(), constVal ? 1 : 0);
247 rewriter.replaceAllUsesWith(result, c);
248 if (op->use_empty())
249 rewriter.eraseOp(op);
250 return success();
251 }
252
253 SmallVector<std::pair<Value, bool>> guardFacts;
254 Region &guardRegion;
255 Region &actionRegion;
256};
257
258/// Simplify action blocks by propagating guard conditions.
259/// When a transition is taken, its guard condition is known to be true.
260/// This function finds expressions in the action region that are structurally
261/// identical to the guard condition and replaces them with constant true.
262/// It also handles NOT (xor with true) and AND decomposition to propagate
263/// additional known values.
264static void simplifyActionWithGuard(TransitionOp transition,
265 OpBuilder &builder) {
266 Region &guardRegion = transition.getGuard();
267 Region &actionRegion = transition.getAction();
268
269 if (guardRegion.empty() || actionRegion.empty())
270 return;
271
272 auto guardReturn =
273 dyn_cast<fsm::ReturnOp>(guardRegion.front().getTerminator());
274 if (!guardReturn)
275 return;
276
277 Location loc = guardReturn.getLoc();
278 Value guardCondition = guardReturn.getOperand();
279
280 // Collect (expression, isTrue) pairs to propagate into the action.
281 SmallVector<std::pair<Value, bool>> guardFacts;
282
283 // Use a worklist to decompose AND expressions.
284 SmallVector<std::pair<Value, bool>> worklist;
285 worklist.push_back({guardCondition, true});
286
287 while (!worklist.empty()) {
288 auto [cond, isTrue] = worklist.pop_back_val();
289 guardFacts.push_back({cond, isTrue});
290
291 // Decompose NOT: xor(E, true) means E has the opposite truth value.
292 if (auto xorOp = cond.getDefiningOp<XorOp>()) {
293 if (xorOp.isBinaryNot() &&
294 guardRegion.isAncestor(xorOp->getParentRegion()))
295 guardFacts.push_back({xorOp.getOperand(0), !isTrue});
296 }
297
298 // Decompose AND: if the AND is true, each operand is true.
299 if (isTrue) {
300 if (auto andOp = cond.getDefiningOp<AndOp>()) {
301 if (guardRegion.isAncestor(andOp->getParentRegion())) {
302 for (Value operand : andOp.getOperands())
303 worklist.push_back({operand, true});
304 }
305 }
306 }
307 }
308
309 // Replace external guard values (block args or ops outside guard region)
310 // directly via replaceUsesWithIf.
311 for (auto [guardExpr, isTrue] : guardFacts) {
312 bool guardIsExternal =
313 !guardExpr.getDefiningOp() ||
314 !guardRegion.isAncestor(guardExpr.getDefiningOp()->getParentRegion());
315 if (guardIsExternal) {
316 builder.setInsertionPointToStart(&actionRegion.front());
317 auto constOp = hw::ConstantOp::create(builder, loc, guardExpr.getType(),
318 isTrue ? 1 : 0);
319 guardExpr.replaceUsesWithIf(constOp.getResult(), [&](OpOperand &use) {
320 if (!actionRegion.isAncestor(use.getOwner()->getParentRegion()))
321 return false;
322 // Never replace the variable operand of fsm.update because that is the
323 // assignment target, not a value to be folded.
324 if (auto updateOp = dyn_cast<UpdateOp>(use.getOwner()))
325 if (use.getOperandNumber() == 0)
326 return false;
327 return true;
328 });
329 }
330 }
331
332 // Use pattern rewriter for internal guard expression folding.
333 MLIRContext *ctx = transition.getContext();
334 RewritePatternSet patterns(ctx);
335 patterns.add<GuardConditionFoldPattern>(ctx, guardFacts, guardRegion,
336 actionRegion);
337 SmallVector<Operation *> actionOps;
338 actionRegion.walk([&](Operation *op) { actionOps.push_back(op); });
339 GreedyRewriteConfig config;
340 config.setScope(&actionRegion);
341 (void)applyOpPatternsGreedily(
342 actionOps, FrozenRewritePatternSet(std::move(patterns)), config);
343}
344
345/// Checks if a value is a constant or a tree of muxes with constant leaves.
346/// Uses an iterative approach with a visited set to handle cycles.
347static bool isConstantOrConstantTree(Value value) {
348 SmallVector<Value> worklist;
349 llvm::DenseSet<Value> visited;
350
351 worklist.push_back(value);
352 while (!worklist.empty()) {
353 Value current = worklist.pop_back_val();
354
355 // Skip if already visited (handles cycles).
356 if (!visited.insert(current).second)
357 continue;
358
359 Operation *definingOp = current.getDefiningOp();
360 if (!definingOp)
361 return false;
362
363 if (isa<hw::ConstantOp>(definingOp))
364 continue;
365
366 if (auto muxOp = dyn_cast<MuxOp>(definingOp)) {
367 worklist.push_back(muxOp.getTrueValue());
368 worklist.push_back(muxOp.getFalseValue());
369 continue;
370 }
371
372 // Not a constant or mux - not constant-like.
373 return false;
374 }
375 return true;
376}
377
378/// Pushes an ICmp equality comparison through a mux operation.
379/// This transforms `icmp eq (mux cond, x, y), b` into
380/// `mux cond, (icmp eq x, b), (icmp eq y, b)`.
381/// This simplification helps expose constant comparisons that can be folded
382/// during FSM extraction, making transition guards easier to analyze.
383LogicalResult pushIcmp(ICmpOp op, PatternRewriter &rewriter) {
384 APInt lhs, rhs;
385 if (op.getPredicate() == ICmpPredicate::eq &&
386 op.getLhs().getDefiningOp<MuxOp>() &&
387 (isConstantOrConstantTree(op.getLhs()) ||
388 op.getRhs().getDefiningOp<hw::ConstantOp>())) {
389 rewriter.setInsertionPointAfter(op);
390 auto mux = op.getLhs().getDefiningOp<MuxOp>();
391 Value x = mux.getTrueValue();
392 Value y = mux.getFalseValue();
393 Value b = op.getRhs();
394 Location loc = op.getLoc();
395 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
396 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
397 rewriter.replaceOpWithNewOp<MuxOp>(op, mux.getCond(), eq1.getResult(),
398 eq2.getResult());
399 return llvm::success();
400 }
401 if (op.getPredicate() == ICmpPredicate::eq &&
402 op.getRhs().getDefiningOp<MuxOp>() &&
403 (isConstantOrConstantTree(op.getRhs()) ||
404 op.getLhs().getDefiningOp<hw::ConstantOp>())) {
405 rewriter.setInsertionPointAfter(op);
406 auto mux = op.getRhs().getDefiningOp<MuxOp>();
407 Value x = mux.getTrueValue();
408 Value y = mux.getFalseValue();
409 Value b = op.getLhs();
410 Location loc = op.getLoc();
411 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
412 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
413 rewriter.replaceOpWithNewOp<MuxOp>(op, mux.getCond(), eq1.getResult(),
414 eq2.getResult());
415 return llvm::success();
416 }
417 return llvm::failure();
418}
419
420/// Iteratively builds all possible concatenated integer values from the
421/// Cartesian product of value sets.
422static void generateConcatenatedValues(
423 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
424 const llvm::SmallVector<unsigned> &shifts,
425 llvm::SetVector<size_t> &finalPossibleValues) {
426
427 if (allOperandValues.empty()) {
428 finalPossibleValues.insert(0);
429 return;
430 }
431
432 // Start with the values of the first operand, shifted appropriately.
433 llvm::SetVector<size_t> currentResults;
434 for (size_t val : allOperandValues[0])
435 currentResults.insert(val << shifts[0]);
436
437 // For each subsequent operand, combine with all existing partial results.
438 for (size_t operandIdx = 1; operandIdx < allOperandValues.size();
439 ++operandIdx) {
440 llvm::SetVector<size_t> nextResults;
441 unsigned shift = shifts[operandIdx];
442
443 for (size_t partialValue : currentResults) {
444 for (size_t val : allOperandValues[operandIdx]) {
445 nextResults.insert(partialValue | (val << shift));
446 }
447 }
448 currentResults = std::move(nextResults);
449 }
450
451 for (size_t val : currentResults)
452 finalPossibleValues.insert(val);
453}
454
455static llvm::MapVector<Value, int> intToRegMap(SmallVector<seq::CompRegOp> v,
456 int i) {
457 llvm::MapVector<Value, int> m;
458 for (size_t ci = 0; ci < v.size(); ci++) {
459 seq::CompRegOp reg = v[ci];
460 int bits = reg.getType().getIntOrFloatBitWidth();
461 int v = i & ((1 << bits) - 1);
462 m[reg] = v;
463 i = i >> bits;
464 }
465 return m;
466}
467
468static int regMapToInt(SmallVector<seq::CompRegOp> v,
469 llvm::DenseMap<Value, int> m) {
470 int i = 0;
471 int width = 0;
472 for (size_t ci = 0; ci < v.size(); ci++) {
473 seq::CompRegOp reg = v[ci];
474 i += m[reg] * 1ULL << width;
475 width += (reg.getType().getIntOrFloatBitWidth());
476 }
477 return i;
478}
479
480/// Computes the Cartesian product of a list of sets.
481static std::set<llvm::SmallVector<size_t>> calculateCartesianProduct(
482 const llvm::SmallVector<llvm::SetVector<size_t>> &valueSets) {
483 std::set<llvm::SmallVector<size_t>> product;
484 if (valueSets.empty()) {
485 // The Cartesian product of zero sets is a set containing one element:
486 // the empty tuple (represented here by an empty vector).
487 product.insert({});
488 return product;
489 }
490
491 // Initialize the product with the elements of the first set, each in its
492 // own vector.
493 for (size_t value : valueSets.front()) {
494 product.insert({value});
495 }
496
497 // Iteratively build the product. For each subsequent set, create a new
498 // temporary product by appending each of its elements to every combination
499 // already generated.
500 for (size_t i = 1; i < valueSets.size(); ++i) {
501 const auto &currentSet = valueSets[i];
502 if (currentSet.empty()) {
503 // The Cartesian product with an empty set results in an empty set.
504 return {};
505 }
506
507 std::set<llvm::SmallVector<size_t>> newProduct;
508 for (const auto &existingVector : product) {
509 for (size_t newValue : currentSet) {
510 llvm::SmallVector<size_t> newVector = existingVector;
511 newVector.push_back(newValue);
512 newProduct.insert(std::move(newVector));
513 }
514 }
515 product = std::move(newProduct);
516 }
517
518 return product;
519}
520
521static FrozenRewritePatternSet loadPatterns(MLIRContext &context) {
522
523 RewritePatternSet patterns(&context);
524 for (auto *dialect : context.getLoadedDialects())
525 dialect->getCanonicalizationPatterns(patterns);
526 ICmpOp::getCanonicalizationPatterns(patterns, &context);
527 AndOp::getCanonicalizationPatterns(patterns, &context);
528 XorOp::getCanonicalizationPatterns(patterns, &context);
529 MuxOp::getCanonicalizationPatterns(patterns, &context);
530 ConcatOp::getCanonicalizationPatterns(patterns, &context);
531 ExtractOp::getCanonicalizationPatterns(patterns, &context);
532 AddOp::getCanonicalizationPatterns(patterns, &context);
533 OrOp::getCanonicalizationPatterns(patterns, &context);
534 MulOp::getCanonicalizationPatterns(patterns, &context);
535 hw::ConstantOp::getCanonicalizationPatterns(patterns, &context);
536 TransitionOp::getCanonicalizationPatterns(patterns, &context);
537 StateOp::getCanonicalizationPatterns(patterns, &context);
538 MachineOp::getCanonicalizationPatterns(patterns, &context);
539 patterns.add(pushIcmp);
540 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
541 return frozenPatterns;
542}
543
544static LogicalResult
545getReachableStates(llvm::SetVector<size_t> &visitableStates,
546 HWModuleOp moduleOp, size_t currentStateIndex,
547 SmallVector<seq::CompRegOp> registers, OpBuilder opBuilder,
548 bool isInitialState) {
549
550 IRMapping mapping;
551 auto clonedBody =
552 llvm::dyn_cast<HWModuleOp>(opBuilder.clone(*moduleOp, mapping));
553
554 llvm::MapVector<Value, int> stateMap =
555 intToRegMap(registers, currentStateIndex);
556 Operation *terminator = clonedBody.getBody().front().getTerminator();
557 auto output = dyn_cast<hw::OutputOp>(terminator);
558 SmallVector<Value> values;
559
560 for (auto [originalRegValue, constStateValue] : stateMap) {
561
562 Value clonedRegValue = mapping.lookup(originalRegValue);
563 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
564 auto reg = cast<seq::CompRegOp>(clonedRegOp);
565 Type constantType = reg.getType();
566 IntegerAttr constantAttr =
567 opBuilder.getIntegerAttr(constantType, constStateValue);
568 opBuilder.setInsertionPoint(clonedRegOp);
569 auto otherStateConstant =
570 hw::ConstantOp::create(opBuilder, reg.getLoc(), constantAttr);
571 // If the register input is self-referential (input == output), use the
572 // constant we're replacing it with. Otherwise, the value would become
573 // dangling after we erase the register.
574 Value regInput = reg.getInput();
575 if (regInput == clonedRegValue)
576 values.push_back(otherStateConstant.getResult());
577 else
578 values.push_back(regInput);
579 clonedRegValue.replaceAllUsesWith(otherStateConstant.getResult());
580 reg.erase();
581 }
582 opBuilder.setInsertionPointToEnd(clonedBody.front().getBlock());
583 auto newOutput = hw::OutputOp::create(opBuilder, output.getLoc(), values);
584 output.erase();
585 FrozenRewritePatternSet frozenPatterns = loadPatterns(*moduleOp.getContext());
586
587 SmallVector<Operation *> opsToProcess;
588 clonedBody.walk([&](Operation *op) { opsToProcess.push_back(op); });
589
590 bool changed = false;
591 GreedyRewriteConfig config;
592 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns, config,
593 &changed)))
594 return failure();
595
596 llvm::SmallVector<llvm::SetVector<size_t>> pv;
597 for (size_t j = 0; j < newOutput.getNumOperands(); j++) {
598 llvm::SetVector<size_t> possibleValues;
599
600 Value v = newOutput.getOperand(j);
601 addPossibleValues(possibleValues, v);
602 pv.push_back(possibleValues);
603 }
604 std::set<llvm::SmallVector<size_t>> flipped = calculateCartesianProduct(pv);
605 for (llvm::SmallVector<size_t> v : flipped) {
606 llvm::DenseMap<Value, int> m;
607 for (size_t k = 0; k < v.size(); k++) {
608 seq::CompRegOp r = registers[k];
609 m[r] = v[k];
610 }
611
612 int i = regMapToInt(registers, m);
613 visitableStates.insert(i);
614 }
615
616 clonedBody.erase();
617 return success();
618}
619
620// A converter class to handle the logic of converting a single hw.module.
621class HWModuleOpConverter {
622public:
623 HWModuleOpConverter(OpBuilder &builder, HWModuleOp moduleOp,
624 ArrayRef<std::string> stateRegNames)
625 : moduleOp(moduleOp), opBuilder(builder), stateRegNames(stateRegNames) {}
626 LogicalResult run() {
627 SmallVector<seq::CompRegOp> stateRegs;
628 SmallVector<seq::CompRegOp> variableRegs;
629 Value foundClock, foundReset = nullptr;
630 WalkResult walkResult = moduleOp.walk([&](seq::CompRegOp reg) {
631 auto clk = reg.getClk();
632 auto reset = reg.getReset();
633 if (foundClock) {
634 if (clk != foundClock) {
635 reg.emitError("All registers must have the same clock signal.");
636 return WalkResult::interrupt();
637 }
638 } else {
639 foundClock = clk;
640 }
641
642 if (reset) {
643 if (foundReset) {
644 if (reset != foundReset) {
645 reg.emitError("All registers must have the same reset signal.");
646 return WalkResult::interrupt();
647 }
648 } else {
649 foundReset = reset;
650 }
651 }
652
653 // Check that the register type is an integer.
654 if (!isa<IntegerType>(reg.getType())) {
655 reg.emitError("FSM extraction only supports integer-typed registers");
656 return WalkResult::interrupt();
657 }
658 if (isStateRegister(reg)) {
659 stateRegs.push_back(reg);
660 } else {
661 variableRegs.push_back(reg);
662 }
663 return WalkResult::advance();
664 });
665 if (walkResult.wasInterrupted())
666 return failure();
667 if (stateRegs.empty()) {
668 emitError(moduleOp.getLoc())
669 << "Cannot find state register in this FSM. Use the state-regs "
670 "option to specify which registers are state registers.";
671 return failure();
672 }
673 SmallVector<seq::CompRegOp> registers;
674 for (seq::CompRegOp c : stateRegs) {
675 registers.push_back(c);
676 }
677
678 llvm::DenseMap<size_t, StateOp> stateToStateOp;
679 llvm::DenseMap<StateOp, size_t> stateOpToState;
680 // Collect reset arguments to exclude from the FSM's function type.
681 // All CompReg reset signals are ignored during FSM extraction since the
682 // FSM dialect does not have an explicit reset concept. The reset behavior
683 // is only captured in the initial state value.
684 llvm::DenseSet<size_t> asyncResetArguments;
685 Location loc = moduleOp.getLoc();
686 SmallVector<Type> inputTypes = moduleOp.getInputTypes();
687
688 // Create a new FSM machine with the current state.
689 auto resultTypes = moduleOp.getOutputTypes();
690 FunctionType machineType =
691 FunctionType::get(opBuilder.getContext(), inputTypes, resultTypes);
692 StringRef machineName = moduleOp.getName();
693
694 llvm::DenseMap<Value, int> initialStateMap;
695 for (seq::CompRegOp reg : moduleOp.getOps<seq::CompRegOp>()) {
696 Value resetValue = reg.getResetValue();
697 hw::ConstantOp definingConstant;
698 if (resetValue) {
699 definingConstant = resetValue.getDefiningOp<hw::ConstantOp>();
700 } else {
701 // Assume that registers without a reset start at 0
702 reg.emitWarning("Assuming register with no reset starts with value 0");
703 definingConstant =
704 hw::ConstantOp::create(opBuilder, reg.getLoc(), reg.getType(), 0);
705 }
706 if (!definingConstant) {
707 reg->emitError(
708 "cannot find defining constant for reset value of register");
709 return failure();
710 }
711 int resetValueInt =
712 definingConstant.getValueAttr().getValue().getZExtValue();
713 initialStateMap[reg] = resetValueInt;
714 }
715 int initialStateIndex = regMapToInt(registers, initialStateMap);
716
717 std::string initialStateName = "state_" + std::to_string(initialStateIndex);
718
719 // Preserve argument and result names, which are stored as attributes.
720 SmallVector<NamedAttribute> machineAttrs;
721 if (auto argNames = moduleOp->getAttrOfType<ArrayAttr>("argNames"))
722 machineAttrs.emplace_back(opBuilder.getStringAttr("argNames"), argNames);
723 if (auto resNames = moduleOp->getAttrOfType<ArrayAttr>("resultNames"))
724 machineAttrs.emplace_back(opBuilder.getStringAttr("resNames"), resNames);
725
726 // The builder for fsm.MachineOp will create the body region and block
727 // arguments.
728 opBuilder.setInsertionPoint(moduleOp);
729 auto machine =
730 MachineOp::create(opBuilder, loc, machineName, initialStateName,
731 machineType, machineAttrs);
732
733 OpBuilder::InsertionGuard guard(opBuilder);
734 opBuilder.setInsertionPointToStart(&machine.getBody().front());
735 llvm::MapVector<seq::CompRegOp, VariableOp> variableMap;
736 for (seq::CompRegOp varReg : variableRegs) {
737 TypedValue<Type> initialValue = varReg.getResetValue();
738 hw::ConstantOp definingConstant;
739 if (initialValue) {
740 definingConstant = initialValue.getDefiningOp<hw::ConstantOp>();
741 } else {
742 // Assume that registers without a reset start at 0
743 varReg.emitWarning(
744 "Assuming register with no reset starts with value 0");
745 definingConstant = hw::ConstantOp::create(opBuilder, varReg.getLoc(),
746 varReg.getType(), 0);
747 }
748 if (!definingConstant) {
749 varReg->emitError("cannot find defining constant for reset value of "
750 "variable register");
751 return failure();
752 }
753 auto variableOp = VariableOp::create(
754 opBuilder, varReg->getLoc(), varReg.getInput().getType(),
755 definingConstant.getValueAttr(), varReg.getName().value_or("var"));
756 variableMap[varReg] = variableOp;
757 }
758
759 // Load rewrite patterns used for canonicalizing the generated FSM.
760 FrozenRewritePatternSet frozenPatterns =
761 loadPatterns(*moduleOp.getContext());
762
763 SetVector<int> reachableStates;
764 SmallVector<int> worklist;
765
766 worklist.push_back(initialStateIndex);
767 reachableStates.insert(initialStateIndex);
768 // Process states in BFS order. The worklist grows as new reachable states
769 // are discovered, so we use an index-based loop.
770 for (unsigned i = 0; i < worklist.size(); ++i) {
771
772 int currentStateIndex = worklist[i];
773
774 llvm::MapVector<Value, int> stateMap =
775 intToRegMap(registers, currentStateIndex);
776
777 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
778
779 StateOp stateOp;
780
781 if (!stateToStateOp.contains(currentStateIndex)) {
782 stateOp = StateOp::create(opBuilder, loc,
783 "state_" + std::to_string(currentStateIndex));
784 stateToStateOp.insert({currentStateIndex, stateOp});
785 stateOpToState.insert({stateOp, currentStateIndex});
786 } else {
787 stateOp = stateToStateOp.lookup(currentStateIndex);
788 }
789 Region &outputRegion = stateOp.getOutput();
790 Block *outputBlock = &outputRegion.front();
791 opBuilder.setInsertionPointToStart(outputBlock);
792 IRMapping mapping;
793 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), outputRegion,
794 outputBlock->getIterator(), mapping);
795 outputBlock->erase();
796
797 auto *terminator = outputRegion.front().getTerminator();
798 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
799 assert(hwOutputOp && "Expected terminator to be hw.output op");
800
801 // Position the builder to insert the new terminator right before the
802 // old one.
803 OpBuilder::InsertionGuard stateGuard(opBuilder);
804 opBuilder.setInsertionPoint(hwOutputOp);
805
806 // Create the new fsm.OutputOp with the same operands.
807
808 fsm::OutputOp::create(opBuilder, hwOutputOp.getLoc(),
809 hwOutputOp.getOperands());
810
811 // Erase the old terminator.
812 hwOutputOp.erase();
813
814 // Iterate through the state configuration to replace registers
815 // with constants.
816 for (auto &[originalRegValue, variableOp] : variableMap) {
817 Value clonedRegValue = mapping.lookup(originalRegValue);
818 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
819 auto reg = cast<seq::CompRegOp>(clonedRegOp);
820 const auto res = variableOp.getResult();
821 clonedRegValue.replaceAllUsesWith(res);
822 reg.erase();
823 }
824 for (auto const &[originalRegValue, constStateValue] : stateMap) {
825 // Find the cloned register's result value using the mapping.
826 Value clonedRegValue = mapping.lookup(originalRegValue);
827 assert(clonedRegValue &&
828 "Original register value not found in mapping");
829 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
830
831 assert(clonedRegOp && "Cloned value must have a defining op");
832 opBuilder.setInsertionPoint(clonedRegOp);
833 auto r = cast<seq::CompRegOp>(clonedRegOp);
834 TypedValue<IntegerType> registerReset = r.getReset();
835 if (registerReset) {
836 if (BlockArgument blockArg = dyn_cast<BlockArgument>(registerReset)) {
837 asyncResetArguments.insert(blockArg.getArgNumber());
838 auto falseConst = hw::ConstantOp::create(
839 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
840 blockArg.replaceAllUsesWith(falseConst.getResult());
841 }
842 if (auto xorOp = registerReset.getDefiningOp<XorOp>()) {
843 if (xorOp.isBinaryNot()) {
844 Value rhs = xorOp.getOperand(0);
845 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
846 asyncResetArguments.insert(blockArg.getArgNumber());
847 auto trueConst = hw::ConstantOp::create(
848 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
849 blockArg.replaceAllUsesWith(trueConst.getResult());
850 }
851 }
852 }
853 }
854 auto constantOp =
855 hw::ConstantOp::create(opBuilder, clonedRegValue.getLoc(),
856 clonedRegValue.getType(), constStateValue);
857 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
858 clonedRegOp->erase();
859 }
860 GreedyRewriteConfig config;
861 SmallVector<Operation *> opsToProcess;
862 outputRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
863 // Replace references to arguments in the output block with
864 // arguments at the top level.
865 for (auto arg : outputRegion.front().getArguments()) {
866 int argIndex = arg.getArgNumber();
867 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
868 arg.replaceAllUsesWith(topLevelArg);
869 }
870 outputRegion.front().eraseArguments(
871 [](BlockArgument arg) { return true; });
872 FrozenRewritePatternSet patterns(opBuilder.getContext());
873 config.setScope(&outputRegion);
874
875 bool changed = false;
876 if (failed(applyOpPatternsGreedily(opsToProcess, patterns, config,
877 &changed)))
878 return failure();
879 opBuilder.setInsertionPoint(stateOp);
880 // hw.module uses graph regions that allow cycles (e.g., registers feeding
881 // back into themselves). By this point we've replaced all registers with
882 // constants, but cycles in purely combinational logic (e.g., cyclic
883 // muxes) may still exist. Such cycles cannot be converted to FSM.
884 bool sorted = sortTopologically(&outputRegion.front());
885 if (!sorted) {
886 moduleOp.emitError()
887 << "cannot convert module with combinational cycles to FSM";
888 return failure();
889 }
890 Region &transitionRegion = stateOp.getTransitions();
891 llvm::SetVector<size_t> visitableStates;
892 if (failed(getReachableStates(visitableStates, moduleOp,
893 currentStateIndex, registers, opBuilder,
894 currentStateIndex == initialStateIndex)))
895 return failure();
896 for (size_t j : visitableStates) {
897 StateOp toState;
898 if (!stateToStateOp.contains(j)) {
899 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
900 toState =
901 StateOp::create(opBuilder, loc, "state_" + std::to_string(j));
902 stateToStateOp.insert({j, toState});
903 stateOpToState.insert({toState, j});
904 } else {
905 toState = stateToStateOp[j];
906 }
907 opBuilder.setInsertionPointToStart(&transitionRegion.front());
908 auto transitionOp =
909 TransitionOp::create(opBuilder, loc, "state_" + std::to_string(j));
910 Region &guardRegion = transitionOp.getGuard();
911 opBuilder.createBlock(&guardRegion);
912
913 Block &guardBlock = guardRegion.front();
914
915 opBuilder.setInsertionPointToStart(&guardBlock);
916 IRMapping mapping;
917 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), guardRegion,
918 guardBlock.getIterator(), mapping);
919 guardBlock.erase();
920 Block &newGuardBlock = guardRegion.front();
921 Operation *terminator = newGuardBlock.getTerminator();
922 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
923 assert(hwOutputOp && "Expected terminator to be hw.output op");
924
925 llvm::MapVector<Value, int> toStateMap = intToRegMap(registers, j);
926 SmallVector<Value> equalityChecks;
927 for (auto &[originalRegValue, variableOp] : variableMap) {
928 opBuilder.setInsertionPointToStart(&newGuardBlock);
929 Value clonedRegValue = mapping.lookup(originalRegValue);
930 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
931 auto reg = cast<seq::CompRegOp>(clonedRegOp);
932 const auto res = variableOp.getResult();
933 clonedRegValue.replaceAllUsesWith(res);
934 reg.erase();
935 }
936 for (auto const &[originalRegValue, constStateValue] : toStateMap) {
937
938 Value clonedRegValue = mapping.lookup(originalRegValue);
939 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
940 opBuilder.setInsertionPoint(clonedRegOp);
941 auto r = cast<seq::CompRegOp>(clonedRegOp);
942
943 Value registerInput = r.getInput();
944 TypedValue<IntegerType> registerReset = r.getReset();
945 if (registerReset) {
946 if (BlockArgument blockArg =
947 dyn_cast<BlockArgument>(registerReset)) {
948 auto falseConst = hw::ConstantOp::create(
949 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
950 blockArg.replaceAllUsesWith(falseConst.getResult());
951 }
952 if (auto xorOp = registerReset.getDefiningOp<XorOp>()) {
953 if (xorOp.isBinaryNot()) {
954 Value rhs = xorOp.getOperand(0);
955 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
956 auto trueConst = hw::ConstantOp::create(
957 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
958 blockArg.replaceAllUsesWith(trueConst.getResult());
959 }
960 }
961 }
962 }
963 Type constantType = registerInput.getType();
964 IntegerAttr constantAttr =
965 opBuilder.getIntegerAttr(constantType, constStateValue);
966 auto otherStateConstant = hw::ConstantOp::create(
967 opBuilder, hwOutputOp.getLoc(), constantAttr);
968
969 auto doesEqual =
970 ICmpOp::create(opBuilder, hwOutputOp.getLoc(), ICmpPredicate::eq,
971 registerInput, otherStateConstant.getResult());
972 equalityChecks.push_back(doesEqual.getResult());
973 }
974 opBuilder.setInsertionPoint(hwOutputOp);
975 auto allEqualCheck = AndOp::create(opBuilder, hwOutputOp.getLoc(),
976 equalityChecks, false);
977 fsm::ReturnOp::create(opBuilder, hwOutputOp.getLoc(),
978 allEqualCheck.getResult());
979 hwOutputOp.erase();
980 for (BlockArgument arg : newGuardBlock.getArguments()) {
981 int argIndex = arg.getArgNumber();
982 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
983 arg.replaceAllUsesWith(topLevelArg);
984 }
985 newGuardBlock.eraseArguments([](BlockArgument arg) { return true; });
986 llvm::MapVector<Value, int> fromStateMap =
987 intToRegMap(registers, currentStateIndex);
988 for (auto const &[originalRegValue, constStateValue] : fromStateMap) {
989 Value clonedRegValue = mapping.lookup(originalRegValue);
990 assert(clonedRegValue &&
991 "Original register value not found in mapping");
992 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
993 assert(clonedRegOp && "Cloned value must have a defining op");
994 opBuilder.setInsertionPoint(clonedRegOp);
995 auto constantOp =
996 hw::ConstantOp::create(opBuilder, clonedRegValue.getLoc(),
997 clonedRegValue.getType(), constStateValue);
998 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
999 clonedRegOp->erase();
1000 }
1001 Region &actionRegion = transitionOp.getAction();
1002 if (!variableRegs.empty()) {
1003 Block *actionBlock = opBuilder.createBlock(&actionRegion);
1004 opBuilder.setInsertionPointToStart(actionBlock);
1005 IRMapping mapping;
1006 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), actionRegion,
1007 actionBlock->getIterator(), mapping);
1008 actionBlock->erase();
1009 Block &newActionBlock = actionRegion.front();
1010 for (BlockArgument arg : newActionBlock.getArguments()) {
1011 int argIndex = arg.getArgNumber();
1012 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
1013 arg.replaceAllUsesWith(topLevelArg);
1014 }
1015 newActionBlock.eraseArguments([](BlockArgument arg) { return true; });
1016 for (auto &[originalRegValue, variableOp] : variableMap) {
1017 Value clonedRegValue = mapping.lookup(originalRegValue);
1018 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1019 auto reg = cast<seq::CompRegOp>(clonedRegOp);
1020 opBuilder.setInsertionPointToStart(&newActionBlock);
1021 UpdateOp::create(opBuilder, reg.getLoc(), variableOp,
1022 reg.getInput());
1023 const Value res = variableOp.getResult();
1024 clonedRegValue.replaceAllUsesWith(res);
1025 reg.erase();
1026 }
1027 Operation *terminator = actionRegion.back().getTerminator();
1028 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
1029 assert(hwOutputOp && "Expected terminator to be hw.output op");
1030 hwOutputOp.erase();
1031
1032 for (auto const &[originalRegValue, constStateValue] : fromStateMap) {
1033 Value clonedRegValue = mapping.lookup(originalRegValue);
1034 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1035 opBuilder.setInsertionPoint(clonedRegOp);
1036 auto constantOp = hw::ConstantOp::create(
1037 opBuilder, clonedRegValue.getLoc(), clonedRegValue.getType(),
1038 constStateValue);
1039 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
1040 clonedRegOp->erase();
1041 }
1042
1043 GreedyRewriteConfig config;
1044 SmallVector<Operation *> opsToProcess;
1045 actionRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
1046 config.setScope(&actionRegion);
1047
1048 bool changed = false;
1049 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns,
1050 config, &changed)))
1051 return failure();
1052
1053 // hw.module uses graph regions that allow cycles. By this point
1054 // we've replaced all registers with constants, but cycles in purely
1055 // combinational logic may still exist.
1056 bool actionSorted = sortTopologically(&actionRegion.front());
1057 if (!actionSorted) {
1058 moduleOp.emitError()
1059 << "cannot convert module with combinational cycles to FSM";
1060 return failure();
1061 }
1062 }
1063
1064 // hw.module uses graph regions that allow cycles. By this point
1065 // we've replaced all registers with constants, but cycles in purely
1066 // combinational logic may still exist.
1067 bool guardSorted = sortTopologically(&newGuardBlock);
1068 if (!guardSorted) {
1069 moduleOp.emitError()
1070 << "cannot convert module with combinational cycles to FSM";
1071 return failure();
1072 }
1073 SmallVector<Operation *> outputOps;
1074 stateOp.getOutput().walk(
1075 [&](Operation *op) { outputOps.push_back(op); });
1076
1077 bool changed = false;
1078 GreedyRewriteConfig config;
1079 config.setScope(&stateOp.getOutput());
1080 LogicalResult converged = applyOpPatternsGreedily(
1081 outputOps, frozenPatterns, config, &changed);
1082 assert(succeeded(converged) && "canonicalization failed to converge");
1083 SmallVector<Operation *> transitionOps;
1084 stateOp.getTransitions().walk(
1085 [&](Operation *op) { transitionOps.push_back(op); });
1086
1087 GreedyRewriteConfig config2;
1088 config2.setScope(&stateOp.getTransitions());
1089 if (failed(applyOpPatternsGreedily(transitionOps, frozenPatterns,
1090 config2, &changed))) {
1091 return failure();
1092 }
1093
1094 // Propagate guard conditions into action blocks to eliminate
1095 // redundant muxes. The guard already checks the transition
1096 // condition, so the action block can assume it holds.
1097 for (TransitionOp transition :
1098 stateOp.getTransitions().getOps<TransitionOp>())
1099 simplifyActionWithGuard(transition, opBuilder);
1100
1101 // Re-canonicalize after guard propagation to clean up dead ops.
1102 {
1103 SmallVector<Operation *> postOps;
1104 stateOp.getTransitions().walk(
1105 [&](Operation *op) { postOps.push_back(op); });
1106 GreedyRewriteConfig postConfig;
1107 postConfig.setScope(&stateOp.getTransitions());
1108 if (failed(applyOpPatternsGreedily(postOps, frozenPatterns,
1109 postConfig, &changed)))
1110 return failure();
1111 }
1112
1113 for (TransitionOp transition :
1114 stateOp.getTransitions().getOps<TransitionOp>()) {
1115 StateOp nextState = transition.getNextStateOp();
1116 int nextStateIndex = stateOpToState.lookup(nextState);
1117 auto guardConst = transition.getGuardReturn()
1118 .getOperand()
1119 .getDefiningOp<hw::ConstantOp>();
1120 bool nextStateIsReachable =
1121 !guardConst || (guardConst.getValueAttr().getInt() != 0);
1122 // If we find a valid next state and haven't seen it before, add it to
1123 // the worklist and the set of reachable states.
1124 if (nextStateIsReachable &&
1125 !reachableStates.contains(nextStateIndex)) {
1126 worklist.push_back(nextStateIndex);
1127 reachableStates.insert(nextStateIndex);
1128 }
1129 }
1130 }
1131 }
1132
1133 // Clean up unreachable states. States without an output region are
1134 // placeholder states that were created during reachability analysis but
1135 // never populated (i.e., they are unreachable from the initial state).
1136 SmallVector<StateOp> statesToErase;
1137
1138 // Collect unreachable states (those without an output op).
1139 for (StateOp stateOp : machine.getOps<StateOp>()) {
1140 if (!stateOp.getOutputOp()) {
1141 statesToErase.push_back(stateOp);
1142 }
1143 }
1144
1145 // Erase states in a separate loop to avoid iterator invalidation. We first
1146 // collect all states to erase, then iterate over that list. This is
1147 // necessary because erasing a state while iterating over machine.getOps()
1148 // would invalidate the iterator.
1149 for (StateOp stateOp : statesToErase) {
1150 for (TransitionOp transition : machine.getOps<TransitionOp>()) {
1151 if (transition.getNextStateOp().getSymName() == stateOp.getSymName()) {
1152 transition.erase();
1153 }
1154 }
1155 stateOp.erase();
1156 }
1157
1158 llvm::DenseSet<BlockArgument> asyncResetBlockArguments;
1159 for (auto arg : machine.getBody().front().getArguments()) {
1160 if (asyncResetArguments.contains(arg.getArgNumber())) {
1161 asyncResetBlockArguments.insert(arg);
1162 }
1163 }
1164
1165 // Emit a warning if reset signals were detected and removed.
1166 // The FSM dialect does not support reset signals, so the reset behavior
1167 // is only captured in the initial state. The original reset triggering
1168 // mechanism is not preserved.
1169 if (!asyncResetBlockArguments.empty()) {
1170 moduleOp.emitWarning()
1171 << "reset signals detected and removed from FSM; "
1172 "reset behavior is captured only in the initial state";
1173 }
1174
1175 Block &front = machine.getBody().front();
1176 front.eraseArguments([&](BlockArgument arg) {
1177 return asyncResetBlockArguments.contains(arg);
1178 });
1179
1180 if (llvm::any_of(front.getArguments(), [](BlockArgument arg) {
1181 return arg.getType() == seq::ClockType::get(arg.getContext()) &&
1182 arg.hasNUsesOrMore(1);
1183 })) {
1184 moduleOp.emitError("Clock uses outside register clocking are not "
1185 "currently supported.");
1186 return failure();
1187 }
1188 machine.getBody().front().eraseArguments([&](BlockArgument arg) {
1189 return arg.getType() == seq::ClockType::get(arg.getContext());
1190 });
1191 FunctionType oldFunctionType = machine.getFunctionType();
1192 SmallVector<Type> inputsWithoutClock;
1193 for (unsigned int i = 0; i < oldFunctionType.getNumInputs(); i++) {
1194 Type input = oldFunctionType.getInput(i);
1195 if (input != seq::ClockType::get(input.getContext()) &&
1196 !asyncResetArguments.contains(i))
1197 inputsWithoutClock.push_back(input);
1198 }
1199
1200 FunctionType newFunctionType = FunctionType::get(
1201 opBuilder.getContext(), inputsWithoutClock, resultTypes);
1202
1203 machine.setFunctionType(newFunctionType);
1204 moduleOp.erase();
1205 return success();
1206 }
1207
1208private:
1209 /// Helper function to determine if a register is a state register.
1210 bool isStateRegister(seq::CompRegOp reg) const {
1211 auto regName = reg.getName();
1212 if (!regName)
1213 return false;
1214
1215 // If user specified state registers, check if this register's name matches
1216 // any of them.
1217 if (!stateRegNames.empty()) {
1218 return llvm::is_contained(stateRegNames, regName->str());
1219 }
1220
1221 // Default behavior: infer state registers by checking if the name contains
1222 // "state".
1223 return regName->contains("state");
1224 }
1225
1226 HWModuleOp moduleOp;
1227 OpBuilder &opBuilder;
1228 ArrayRef<std::string> stateRegNames;
1229};
1230
1231} // namespace
1232
1233namespace {
1234struct CoreToFSMPass : public circt::impl::ConvertCoreToFSMBase<CoreToFSMPass> {
1235 using ConvertCoreToFSMBase<CoreToFSMPass>::ConvertCoreToFSMBase;
1236
1237 void runOnOperation() override {
1238 auto module = getOperation();
1239 OpBuilder builder(module);
1240
1241 SmallVector<HWModuleOp> modules;
1242 for (auto hwModule : module.getOps<HWModuleOp>()) {
1243 modules.push_back(hwModule);
1244 }
1245
1246 // Check for hw.instance operations - instance conversion is not supported.
1247 for (auto hwModule : modules) {
1248 for (auto instance : hwModule.getOps<hw::InstanceOp>()) {
1249 instance.emitError() << "instance conversion is not yet supported";
1250 signalPassFailure();
1251 return;
1252 }
1253 }
1254
1255 for (auto hwModule : modules) {
1256 builder.setInsertionPoint(hwModule);
1257 HWModuleOpConverter converter(builder, hwModule, stateRegs);
1258 if (failed(converter.run())) {
1259 signalPassFailure();
1260 return;
1261 }
1262 }
1263 }
1264};
1265} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
create(*operands)
Definition fsm.py:160
create(name)
Definition fsm.py:144
create(to_state)
Definition fsm.py:110
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:445
Definition fsm.py:1
Definition hw.py:1
Definition seq.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21