CIRCT 23.0.0git
Loading...
Searching...
No Matches
CoreToFSM.cpp
Go to the documentation of this file.
1//===CoreToFSM.cpp - Convert Core Dialects (HW + Seq + Comb) to FSM Dialect===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17#include "circt/Support/LLVM.h"
18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/IR/Block.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/Diagnostics.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "mlir/Transforms/Passes.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/LogicalResult.h"
36#include <string>
37
38namespace circt {
39#define GEN_PASS_DEF_CONVERTCORETOFSM
40#include "circt/Conversion/Passes.h.inc"
41} // namespace circt
42
43using namespace mlir;
44using namespace circt;
45using namespace hw;
46using namespace comb;
47using namespace fsm;
48
49namespace {
50
51// Forward declaration for our helper function
52static void generateConcatenatedValues(
53 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
54 const llvm::SmallVector<unsigned> &shifts,
55 llvm::SetVector<size_t> &finalPossibleValues);
56
57/// Internal helper with visited set to detect cycles.
58static void addPossibleValuesImpl(llvm::SetVector<size_t> &possibleValues,
59 Value v, llvm::DenseSet<Value> &visited) {
60 // Detect cycles - if we've seen this value before, skip it.
61 if (!visited.insert(v).second)
62 return;
63
64 if (auto c = dyn_cast_or_null<hw::ConstantOp>(v.getDefiningOp())) {
65 possibleValues.insert(c.getValueAttr().getValue().getZExtValue());
66 return;
67 }
68 if (auto m = dyn_cast_or_null<MuxOp>(v.getDefiningOp())) {
69 addPossibleValuesImpl(possibleValues, m.getTrueValue(), visited);
70 addPossibleValuesImpl(possibleValues, m.getFalseValue(), visited);
71 return;
72 }
73
74 if (auto concatOp = dyn_cast_or_null<ConcatOp>(v.getDefiningOp())) {
75 llvm::SmallVector<llvm::SetVector<size_t>> allOperandValues;
76 llvm::SmallVector<unsigned> operandWidths;
77
78 for (Value operand : concatOp.getOperands()) {
79 llvm::SetVector<size_t> operandPossibleValues;
80 addPossibleValuesImpl(operandPossibleValues, operand, visited);
81
82 // It's crucial to handle the case where a sub-computation is too complex.
83 // If we can't determine specific values for an operand, we must
84 // pessimistically assume it can be any value its bitwidth allows.
85 auto opType = dyn_cast<IntegerType>(operand.getType());
86 // comb.concat only accepts signless integer operands by definition in
87 // CIRCT's type system, so this assertion should always hold for valid IR.
88 assert(opType && "comb.concat operand must be an integer type");
89 unsigned width = opType.getWidth();
90 if (operandPossibleValues.empty()) {
91 uint64_t numStates = 1ULL << width;
92 // Add a threshold to prevent combinatorial explosion on large unknown
93 // inputs.
94 if (numStates > 256) { // Heuristic threshold
95 // If the search space is too large, we abandon the analysis for this
96 // path. The outer function will fall back to its own full-range
97 // default.
98 v.getDefiningOp()->emitWarning()
99 << "Search space too large (>" << 256
100 << " states) for operand with bitwidth " << width
101 << "; abandoning analysis for this path";
102 return;
103 }
104 for (uint64_t i = 0; i < numStates; ++i)
105 operandPossibleValues.insert(i);
106 }
107
108 allOperandValues.push_back(operandPossibleValues);
109 operandWidths.push_back(width);
110 }
111
112 // The shift for operand `i` is the sum of the widths of operands `i+1` to
113 // `n-1`.
114 llvm::SmallVector<unsigned> shifts(concatOp.getNumOperands(), 0);
115 for (int i = concatOp.getNumOperands() - 2; i >= 0; --i) {
116 shifts[i] = shifts[i + 1] + operandWidths[i + 1];
117 }
118
119 generateConcatenatedValues(allOperandValues, shifts, possibleValues);
120 return;
121 }
122
123 // --- Fallback Case ---
124 // If the operation is not recognized, assume all possible values for its
125 // bitwidth.
126
127 auto addrType = dyn_cast<IntegerType>(v.getType());
128 if (!addrType)
129 return; // Not an integer type we can analyze
130
131 unsigned bitWidth = addrType.getWidth();
132 // Again, use a threshold to avoid trying to enumerate 2^64 values.
133 if (bitWidth > 16) {
134 if (v.getDefiningOp())
135 v.getDefiningOp()->emitWarning()
136 << "Bitwidth " << bitWidth
137 << " too large (>16); abandoning analysis for this path";
138 return;
139 }
140
141 uint64_t numRegStates = 1ULL << bitWidth;
142 for (size_t i = 0; i < numRegStates; i++) {
143 possibleValues.insert(i);
144 }
145}
146
147static void addPossibleValues(llvm::SetVector<size_t> &possibleValues,
148 Value v) {
149 llvm::DenseSet<Value> visited;
150 addPossibleValuesImpl(possibleValues, v, visited);
151}
152
153/// Checks if two values compute the same function structurally.
154/// Values defined outside their respective regions (e.g., machine block
155/// arguments or fsm.variable results) are compared by SSA identity.
156/// Values defined by operations within their regions are compared using
157/// MLIR's OperationEquivalence with a custom equivalence callback.
158static bool areStructurallyEquivalent(Value a, Value b, Region &regionA,
159 Region &regionB) {
160 if (a == b)
161 return true;
162
163 Operation *opA = a.getDefiningOp();
164 Operation *opB = b.getDefiningOp();
165 if (!opA || !opB)
166 return false;
167
168 bool aIsLocal = regionA.isAncestor(opA->getParentRegion());
169 bool bIsLocal = regionB.isAncestor(opB->getParentRegion());
170 if (aIsLocal != bIsLocal)
171 return false;
172 if (!aIsLocal)
173 return false;
174
175 // Both local: compare result index and delegate structural comparison
176 // to MLIR's OperationEquivalence.
177 if (cast<OpResult>(a).getResultNumber() !=
178 cast<OpResult>(b).getResultNumber())
179 return false;
180
181 return OperationEquivalence::isEquivalentTo(
182 opA, opB,
183 [&](Value lhs, Value rhs) -> LogicalResult {
184 return success(areStructurallyEquivalent(lhs, rhs, regionA, regionB));
185 },
186 /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
187}
188
189/// RewritePattern that folds expressions in an action region that are
190/// structurally equivalent to guard conditions into boolean constants.
191
192class GuardConditionFoldPattern : public RewritePattern {
193public:
194 GuardConditionFoldPattern(MLIRContext *ctx,
195 ArrayRef<std::pair<Value, bool>> guardFacts,
196 Region &guardRegion, Region &actionRegion)
197 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/10, ctx),
198 guardFacts(guardFacts.begin(), guardFacts.end()),
199 guardRegion(guardRegion), actionRegion(actionRegion) {}
200
201 LogicalResult matchAndRewrite(Operation *op,
202 PatternRewriter &rewriter) const override {
203 if (!actionRegion.isAncestor(op->getParentRegion()))
204 return failure();
205
206 // Skip constants because replacing a constant with an identical constant
207 // creates an infinite loop in the greedy rewriter.
208 if (isa<hw::ConstantOp>(op))
209 return failure();
210
211 for (Value result : op->getResults()) {
212 if (!result.getType().isInteger(1) || result.use_empty())
213 continue;
214
215 for (auto [guardExpr, isTrue] : guardFacts) {
216 // Direct structural match.
217 if (areStructurallyEquivalent(guardExpr, result, guardRegion,
218 actionRegion))
219 return replaceWithConstant(result, isTrue, op, rewriter);
220
221 // ICmp predicate inversion: if the guard contains icmp P(X, Y),
222 // find icmp !P(X, Y) in the action and replace with !isTrue.
223 if (auto guardIcmp = guardExpr.getDefiningOp<ICmpOp>()) {
224 if (auto actionIcmp = dyn_cast<ICmpOp>(op)) {
225 if (actionIcmp.getPredicate() ==
226 ICmpOp::getNegatedPredicate(guardIcmp.getPredicate()) &&
227 areStructurallyEquivalent(guardIcmp.getLhs(),
228 actionIcmp.getLhs(), guardRegion,
229 actionRegion) &&
230 areStructurallyEquivalent(guardIcmp.getRhs(),
231 actionIcmp.getRhs(), guardRegion,
232 actionRegion))
233 return replaceWithConstant(result, !isTrue, op, rewriter);
234 }
235 }
236 }
237 }
238 return failure();
239 }
240
241private:
242 LogicalResult replaceWithConstant(Value result, bool constVal, Operation *op,
243 PatternRewriter &rewriter) const {
244 rewriter.setInsertionPointToStart(&actionRegion.front());
245 Value c = hw::ConstantOp::create(rewriter, op->getLoc(),
246 rewriter.getI1Type(), constVal ? 1 : 0);
247 rewriter.replaceAllUsesWith(result, c);
248 if (op->use_empty())
249 rewriter.eraseOp(op);
250 return success();
251 }
252
253 SmallVector<std::pair<Value, bool>> guardFacts;
254 Region &guardRegion;
255 Region &actionRegion;
256};
257
258/// Simplify action blocks by propagating guard conditions.
259/// When a transition is taken, its guard condition is known to be true.
260/// This function finds expressions in the action region that are structurally
261/// identical to the guard condition and replaces them with constant true.
262/// It also handles NOT (xor with true) and AND decomposition to propagate
263/// additional known values.
264static void simplifyActionWithGuard(TransitionOp transition,
265 OpBuilder &builder) {
266 Region &guardRegion = transition.getGuard();
267 Region &actionRegion = transition.getAction();
268
269 if (guardRegion.empty() || actionRegion.empty())
270 return;
271
272 auto guardReturn =
273 dyn_cast<fsm::ReturnOp>(guardRegion.front().getTerminator());
274 if (!guardReturn)
275 return;
276
277 Location loc = guardReturn.getLoc();
278 Value guardCondition = guardReturn.getOperand();
279
280 // Collect (expression, isTrue) pairs to propagate into the action.
281 SmallVector<std::pair<Value, bool>> guardFacts;
282
283 // Use a worklist to decompose AND expressions.
284 SmallVector<std::pair<Value, bool>> worklist;
285 worklist.push_back({guardCondition, true});
286
287 while (!worklist.empty()) {
288 auto [cond, isTrue] = worklist.pop_back_val();
289 guardFacts.push_back({cond, isTrue});
290
291 // Decompose NOT: xor(E, true) means E has the opposite truth value.
292 if (auto xorOp = cond.getDefiningOp<XorOp>()) {
293 if (xorOp.isBinaryNot() &&
294 guardRegion.isAncestor(xorOp->getParentRegion()))
295 guardFacts.push_back({xorOp.getOperand(0), !isTrue});
296 }
297
298 // Decompose AND: if the AND is true, each operand is true.
299 if (isTrue) {
300 if (auto andOp = cond.getDefiningOp<AndOp>()) {
301 if (guardRegion.isAncestor(andOp->getParentRegion())) {
302 for (Value operand : andOp.getOperands())
303 worklist.push_back({operand, true});
304 }
305 }
306 }
307 }
308
309 // Replace external guard values (block args or ops outside guard region)
310 // directly via replaceUsesWithIf.
311 for (auto [guardExpr, isTrue] : guardFacts) {
312 bool guardIsExternal =
313 !guardExpr.getDefiningOp() ||
314 !guardRegion.isAncestor(guardExpr.getDefiningOp()->getParentRegion());
315 if (guardIsExternal) {
316 builder.setInsertionPointToStart(&actionRegion.front());
317 auto constOp = hw::ConstantOp::create(builder, loc, guardExpr.getType(),
318 isTrue ? 1 : 0);
319 guardExpr.replaceUsesWithIf(constOp.getResult(), [&](OpOperand &use) {
320 return actionRegion.isAncestor(use.getOwner()->getParentRegion());
321 });
322 }
323 }
324
325 // Use pattern rewriter for internal guard expression folding.
326 MLIRContext *ctx = transition.getContext();
327 RewritePatternSet patterns(ctx);
328 patterns.add<GuardConditionFoldPattern>(ctx, guardFacts, guardRegion,
329 actionRegion);
330 SmallVector<Operation *> actionOps;
331 actionRegion.walk([&](Operation *op) { actionOps.push_back(op); });
332 GreedyRewriteConfig config;
333 config.setScope(&actionRegion);
334 (void)applyOpPatternsGreedily(
335 actionOps, FrozenRewritePatternSet(std::move(patterns)), config);
336}
337
338/// Checks if a value is a constant or a tree of muxes with constant leaves.
339/// Uses an iterative approach with a visited set to handle cycles.
340static bool isConstantOrConstantTree(Value value) {
341 SmallVector<Value> worklist;
342 llvm::DenseSet<Value> visited;
343
344 worklist.push_back(value);
345 while (!worklist.empty()) {
346 Value current = worklist.pop_back_val();
347
348 // Skip if already visited (handles cycles).
349 if (!visited.insert(current).second)
350 continue;
351
352 Operation *definingOp = current.getDefiningOp();
353 if (!definingOp)
354 return false;
355
356 if (isa<hw::ConstantOp>(definingOp))
357 continue;
358
359 if (auto muxOp = dyn_cast<MuxOp>(definingOp)) {
360 worklist.push_back(muxOp.getTrueValue());
361 worklist.push_back(muxOp.getFalseValue());
362 continue;
363 }
364
365 // Not a constant or mux - not constant-like.
366 return false;
367 }
368 return true;
369}
370
371/// Pushes an ICmp equality comparison through a mux operation.
372/// This transforms `icmp eq (mux cond, x, y), b` into
373/// `mux cond, (icmp eq x, b), (icmp eq y, b)`.
374/// This simplification helps expose constant comparisons that can be folded
375/// during FSM extraction, making transition guards easier to analyze.
376LogicalResult pushIcmp(ICmpOp op, PatternRewriter &rewriter) {
377 APInt lhs, rhs;
378 if (op.getPredicate() == ICmpPredicate::eq &&
379 op.getLhs().getDefiningOp<MuxOp>() &&
380 (isConstantOrConstantTree(op.getLhs()) ||
381 op.getRhs().getDefiningOp<hw::ConstantOp>())) {
382 rewriter.setInsertionPointAfter(op);
383 auto mux = op.getLhs().getDefiningOp<MuxOp>();
384 Value x = mux.getTrueValue();
385 Value y = mux.getFalseValue();
386 Value b = op.getRhs();
387 Location loc = op.getLoc();
388 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
389 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
390 rewriter.replaceOpWithNewOp<MuxOp>(op, mux.getCond(), eq1.getResult(),
391 eq2.getResult());
392 return llvm::success();
393 }
394 if (op.getPredicate() == ICmpPredicate::eq &&
395 op.getRhs().getDefiningOp<MuxOp>() &&
396 (isConstantOrConstantTree(op.getRhs()) ||
397 op.getLhs().getDefiningOp<hw::ConstantOp>())) {
398 rewriter.setInsertionPointAfter(op);
399 auto mux = op.getRhs().getDefiningOp<MuxOp>();
400 Value x = mux.getTrueValue();
401 Value y = mux.getFalseValue();
402 Value b = op.getLhs();
403 Location loc = op.getLoc();
404 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
405 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
406 rewriter.replaceOpWithNewOp<MuxOp>(op, mux.getCond(), eq1.getResult(),
407 eq2.getResult());
408 return llvm::success();
409 }
410 return llvm::failure();
411}
412
413/// Iteratively builds all possible concatenated integer values from the
414/// Cartesian product of value sets.
415static void generateConcatenatedValues(
416 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
417 const llvm::SmallVector<unsigned> &shifts,
418 llvm::SetVector<size_t> &finalPossibleValues) {
419
420 if (allOperandValues.empty()) {
421 finalPossibleValues.insert(0);
422 return;
423 }
424
425 // Start with the values of the first operand, shifted appropriately.
426 llvm::SetVector<size_t> currentResults;
427 for (size_t val : allOperandValues[0])
428 currentResults.insert(val << shifts[0]);
429
430 // For each subsequent operand, combine with all existing partial results.
431 for (size_t operandIdx = 1; operandIdx < allOperandValues.size();
432 ++operandIdx) {
433 llvm::SetVector<size_t> nextResults;
434 unsigned shift = shifts[operandIdx];
435
436 for (size_t partialValue : currentResults) {
437 for (size_t val : allOperandValues[operandIdx]) {
438 nextResults.insert(partialValue | (val << shift));
439 }
440 }
441 currentResults = std::move(nextResults);
442 }
443
444 finalPossibleValues = std::move(currentResults);
445}
446
447static llvm::MapVector<Value, int> intToRegMap(SmallVector<seq::CompRegOp> v,
448 int i) {
449 llvm::MapVector<Value, int> m;
450 for (size_t ci = 0; ci < v.size(); ci++) {
451 seq::CompRegOp reg = v[ci];
452 int bits = reg.getType().getIntOrFloatBitWidth();
453 int v = i & ((1 << bits) - 1);
454 m[reg] = v;
455 i = i >> bits;
456 }
457 return m;
458}
459
460static int regMapToInt(SmallVector<seq::CompRegOp> v,
461 llvm::DenseMap<Value, int> m) {
462 int i = 0;
463 int width = 0;
464 for (size_t ci = 0; ci < v.size(); ci++) {
465 seq::CompRegOp reg = v[ci];
466 i += m[reg] * 1ULL << width;
467 width += (reg.getType().getIntOrFloatBitWidth());
468 }
469 return i;
470}
471
472/// Computes the Cartesian product of a list of sets.
473static std::set<llvm::SmallVector<size_t>> calculateCartesianProduct(
474 const llvm::SmallVector<llvm::SetVector<size_t>> &valueSets) {
475 std::set<llvm::SmallVector<size_t>> product;
476 if (valueSets.empty()) {
477 // The Cartesian product of zero sets is a set containing one element:
478 // the empty tuple (represented here by an empty vector).
479 product.insert({});
480 return product;
481 }
482
483 // Initialize the product with the elements of the first set, each in its
484 // own vector.
485 for (size_t value : valueSets.front()) {
486 product.insert({value});
487 }
488
489 // Iteratively build the product. For each subsequent set, create a new
490 // temporary product by appending each of its elements to every combination
491 // already generated.
492 for (size_t i = 1; i < valueSets.size(); ++i) {
493 const auto &currentSet = valueSets[i];
494 if (currentSet.empty()) {
495 // The Cartesian product with an empty set results in an empty set.
496 return {};
497 }
498
499 std::set<llvm::SmallVector<size_t>> newProduct;
500 for (const auto &existingVector : product) {
501 for (size_t newValue : currentSet) {
502 llvm::SmallVector<size_t> newVector = existingVector;
503 newVector.push_back(newValue);
504 newProduct.insert(std::move(newVector));
505 }
506 }
507 product = std::move(newProduct);
508 }
509
510 return product;
511}
512
513static FrozenRewritePatternSet loadPatterns(MLIRContext &context) {
514
515 RewritePatternSet patterns(&context);
516 for (auto *dialect : context.getLoadedDialects())
517 dialect->getCanonicalizationPatterns(patterns);
518 ICmpOp::getCanonicalizationPatterns(patterns, &context);
519 AndOp::getCanonicalizationPatterns(patterns, &context);
520 XorOp::getCanonicalizationPatterns(patterns, &context);
521 MuxOp::getCanonicalizationPatterns(patterns, &context);
522 ConcatOp::getCanonicalizationPatterns(patterns, &context);
523 ExtractOp::getCanonicalizationPatterns(patterns, &context);
524 AddOp::getCanonicalizationPatterns(patterns, &context);
525 OrOp::getCanonicalizationPatterns(patterns, &context);
526 MulOp::getCanonicalizationPatterns(patterns, &context);
527 hw::ConstantOp::getCanonicalizationPatterns(patterns, &context);
528 TransitionOp::getCanonicalizationPatterns(patterns, &context);
529 StateOp::getCanonicalizationPatterns(patterns, &context);
530 MachineOp::getCanonicalizationPatterns(patterns, &context);
531 patterns.add(pushIcmp);
532 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
533 return frozenPatterns;
534}
535
536static LogicalResult
537getReachableStates(llvm::SetVector<size_t> &visitableStates,
538 HWModuleOp moduleOp, size_t currentStateIndex,
539 SmallVector<seq::CompRegOp> registers, OpBuilder opBuilder,
540 bool isInitialState) {
541
542 IRMapping mapping;
543 auto clonedBody =
544 llvm::dyn_cast<HWModuleOp>(opBuilder.clone(*moduleOp, mapping));
545
546 llvm::MapVector<Value, int> stateMap =
547 intToRegMap(registers, currentStateIndex);
548 Operation *terminator = clonedBody.getBody().front().getTerminator();
549 auto output = dyn_cast<hw::OutputOp>(terminator);
550 SmallVector<Value> values;
551
552 for (auto [originalRegValue, constStateValue] : stateMap) {
553
554 Value clonedRegValue = mapping.lookup(originalRegValue);
555 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
556 auto reg = cast<seq::CompRegOp>(clonedRegOp);
557 Type constantType = reg.getType();
558 IntegerAttr constantAttr =
559 opBuilder.getIntegerAttr(constantType, constStateValue);
560 opBuilder.setInsertionPoint(clonedRegOp);
561 auto otherStateConstant =
562 hw::ConstantOp::create(opBuilder, reg.getLoc(), constantAttr);
563 // If the register input is self-referential (input == output), use the
564 // constant we're replacing it with. Otherwise, the value would become
565 // dangling after we erase the register.
566 Value regInput = reg.getInput();
567 if (regInput == clonedRegValue)
568 values.push_back(otherStateConstant.getResult());
569 else
570 values.push_back(regInput);
571 clonedRegValue.replaceAllUsesWith(otherStateConstant.getResult());
572 reg.erase();
573 }
574 opBuilder.setInsertionPointToEnd(clonedBody.front().getBlock());
575 auto newOutput = hw::OutputOp::create(opBuilder, output.getLoc(), values);
576 output.erase();
577 FrozenRewritePatternSet frozenPatterns = loadPatterns(*moduleOp.getContext());
578
579 SmallVector<Operation *> opsToProcess;
580 clonedBody.walk([&](Operation *op) { opsToProcess.push_back(op); });
581
582 bool changed = false;
583 GreedyRewriteConfig config;
584 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns, config,
585 &changed)))
586 return failure();
587
588 llvm::SmallVector<llvm::SetVector<size_t>> pv;
589 for (size_t j = 0; j < newOutput.getNumOperands(); j++) {
590 llvm::SetVector<size_t> possibleValues;
591
592 Value v = newOutput.getOperand(j);
593 addPossibleValues(possibleValues, v);
594 pv.push_back(possibleValues);
595 }
596 std::set<llvm::SmallVector<size_t>> flipped = calculateCartesianProduct(pv);
597 for (llvm::SmallVector<size_t> v : flipped) {
598 llvm::DenseMap<Value, int> m;
599 for (size_t k = 0; k < v.size(); k++) {
600 seq::CompRegOp r = registers[k];
601 m[r] = v[k];
602 }
603
604 int i = regMapToInt(registers, m);
605 visitableStates.insert(i);
606 }
607
608 clonedBody.erase();
609 return success();
610}
611
612// A converter class to handle the logic of converting a single hw.module.
613class HWModuleOpConverter {
614public:
615 HWModuleOpConverter(OpBuilder &builder, HWModuleOp moduleOp,
616 ArrayRef<std::string> stateRegNames)
617 : moduleOp(moduleOp), opBuilder(builder), stateRegNames(stateRegNames) {}
618 LogicalResult run() {
619 SmallVector<seq::CompRegOp> stateRegs;
620 SmallVector<seq::CompRegOp> variableRegs;
621 WalkResult walkResult = moduleOp.walk([&](seq::CompRegOp reg) {
622 // Check that the register type is an integer.
623 if (!isa<IntegerType>(reg.getType())) {
624 reg.emitError("FSM extraction only supports integer-typed registers");
625 return WalkResult::interrupt();
626 }
627 if (isStateRegister(reg)) {
628 stateRegs.push_back(reg);
629 } else {
630 variableRegs.push_back(reg);
631 }
632 return WalkResult::advance();
633 });
634 if (walkResult.wasInterrupted())
635 return failure();
636 if (stateRegs.empty()) {
637 emitError(moduleOp.getLoc())
638 << "Cannot find state register in this FSM. Use the state-regs "
639 "option to specify which registers are state registers.";
640 return failure();
641 }
642 SmallVector<seq::CompRegOp> registers;
643 for (seq::CompRegOp c : stateRegs) {
644 registers.push_back(c);
645 }
646
647 llvm::DenseMap<size_t, StateOp> stateToStateOp;
648 llvm::DenseMap<StateOp, size_t> stateOpToState;
649 // Collect reset arguments to exclude from the FSM's function type.
650 // All CompReg reset signals are ignored during FSM extraction since the
651 // FSM dialect does not have an explicit reset concept. The reset behavior
652 // is only captured in the initial state value.
653 llvm::DenseSet<size_t> asyncResetArguments;
654 Location loc = moduleOp.getLoc();
655 SmallVector<Type> inputTypes = moduleOp.getInputTypes();
656
657 // Create a new FSM machine with the current state.
658 auto resultTypes = moduleOp.getOutputTypes();
659 FunctionType machineType =
660 FunctionType::get(opBuilder.getContext(), inputTypes, resultTypes);
661 StringRef machineName = moduleOp.getName();
662
663 llvm::DenseMap<Value, int> initialStateMap;
664 for (seq::CompRegOp reg : moduleOp.getOps<seq::CompRegOp>()) {
665 Value resetValue = reg.getResetValue();
666 hw::ConstantOp definingConstant;
667 if (resetValue) {
668 definingConstant = resetValue.getDefiningOp<hw::ConstantOp>();
669 } else {
670 // Assume that registers without a reset start at 0
671 reg.emitWarning("Assuming register with no reset starts with value 0");
672 definingConstant =
673 hw::ConstantOp::create(opBuilder, reg.getLoc(), reg.getType(), 0);
674 }
675 if (!definingConstant) {
676 reg->emitError(
677 "cannot find defining constant for reset value of register");
678 return failure();
679 }
680 int resetValueInt =
681 definingConstant.getValueAttr().getValue().getZExtValue();
682 initialStateMap[reg] = resetValueInt;
683 }
684 int initialStateIndex = regMapToInt(registers, initialStateMap);
685
686 std::string initialStateName = "state_" + std::to_string(initialStateIndex);
687
688 // Preserve argument and result names, which are stored as attributes.
689 SmallVector<NamedAttribute> machineAttrs;
690 if (auto argNames = moduleOp->getAttrOfType<ArrayAttr>("argNames"))
691 machineAttrs.emplace_back(opBuilder.getStringAttr("argNames"), argNames);
692 if (auto resNames = moduleOp->getAttrOfType<ArrayAttr>("resultNames"))
693 machineAttrs.emplace_back(opBuilder.getStringAttr("resNames"), resNames);
694
695 // The builder for fsm.MachineOp will create the body region and block
696 // arguments.
697 opBuilder.setInsertionPoint(moduleOp);
698 auto machine =
699 MachineOp::create(opBuilder, loc, machineName, initialStateName,
700 machineType, machineAttrs);
701
702 OpBuilder::InsertionGuard guard(opBuilder);
703 opBuilder.setInsertionPointToStart(&machine.getBody().front());
704 llvm::MapVector<seq::CompRegOp, VariableOp> variableMap;
705 for (seq::CompRegOp varReg : variableRegs) {
706 TypedValue<Type> initialValue = varReg.getResetValue();
707 hw::ConstantOp definingConstant;
708 if (initialValue) {
709 definingConstant = initialValue.getDefiningOp<hw::ConstantOp>();
710 } else {
711 // Assume that registers without a reset start at 0
712 varReg.emitWarning(
713 "Assuming register with no reset starts with value 0");
714 definingConstant = hw::ConstantOp::create(opBuilder, varReg.getLoc(),
715 varReg.getType(), 0);
716 }
717 if (!definingConstant) {
718 varReg->emitError("cannot find defining constant for reset value of "
719 "variable register");
720 return failure();
721 }
722 auto variableOp = VariableOp::create(
723 opBuilder, varReg->getLoc(), varReg.getInput().getType(),
724 definingConstant.getValueAttr(), varReg.getName().value_or("var"));
725 variableMap[varReg] = variableOp;
726 }
727
728 // Load rewrite patterns used for canonicalizing the generated FSM.
729 FrozenRewritePatternSet frozenPatterns =
730 loadPatterns(*moduleOp.getContext());
731
732 SetVector<int> reachableStates;
733 SmallVector<int> worklist;
734
735 worklist.push_back(initialStateIndex);
736 reachableStates.insert(initialStateIndex);
737 // Process states in BFS order. The worklist grows as new reachable states
738 // are discovered, so we use an index-based loop.
739 for (unsigned i = 0; i < worklist.size(); ++i) {
740
741 int currentStateIndex = worklist[i];
742
743 llvm::MapVector<Value, int> stateMap =
744 intToRegMap(registers, currentStateIndex);
745
746 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
747
748 StateOp stateOp;
749
750 if (!stateToStateOp.contains(currentStateIndex)) {
751 stateOp = StateOp::create(opBuilder, loc,
752 "state_" + std::to_string(currentStateIndex));
753 stateToStateOp.insert({currentStateIndex, stateOp});
754 stateOpToState.insert({stateOp, currentStateIndex});
755 } else {
756 stateOp = stateToStateOp.lookup(currentStateIndex);
757 }
758 Region &outputRegion = stateOp.getOutput();
759 Block *outputBlock = &outputRegion.front();
760 opBuilder.setInsertionPointToStart(outputBlock);
761 IRMapping mapping;
762 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), outputRegion,
763 outputBlock->getIterator(), mapping);
764 outputBlock->erase();
765
766 auto *terminator = outputRegion.front().getTerminator();
767 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
768 assert(hwOutputOp && "Expected terminator to be hw.output op");
769
770 // Position the builder to insert the new terminator right before the
771 // old one.
772 OpBuilder::InsertionGuard stateGuard(opBuilder);
773 opBuilder.setInsertionPoint(hwOutputOp);
774
775 // Create the new fsm.OutputOp with the same operands.
776
777 fsm::OutputOp::create(opBuilder, hwOutputOp.getLoc(),
778 hwOutputOp.getOperands());
779
780 // Erase the old terminator.
781 hwOutputOp.erase();
782
783 // Iterate through the state configuration to replace registers
784 // with constants.
785 for (auto &[originalRegValue, variableOp] : variableMap) {
786 Value clonedRegValue = mapping.lookup(originalRegValue);
787 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
788 auto reg = cast<seq::CompRegOp>(clonedRegOp);
789 const auto res = variableOp.getResult();
790 clonedRegValue.replaceAllUsesWith(res);
791 reg.erase();
792 }
793 for (auto const &[originalRegValue, constStateValue] : stateMap) {
794 // Find the cloned register's result value using the mapping.
795 Value clonedRegValue = mapping.lookup(originalRegValue);
796 assert(clonedRegValue &&
797 "Original register value not found in mapping");
798 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
799
800 assert(clonedRegOp && "Cloned value must have a defining op");
801 opBuilder.setInsertionPoint(clonedRegOp);
802 auto r = cast<seq::CompRegOp>(clonedRegOp);
803 TypedValue<IntegerType> registerReset = r.getReset();
804 if (registerReset) {
805 if (BlockArgument blockArg = dyn_cast<BlockArgument>(registerReset)) {
806 asyncResetArguments.insert(blockArg.getArgNumber());
807 auto falseConst = hw::ConstantOp::create(
808 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
809 blockArg.replaceAllUsesWith(falseConst.getResult());
810 }
811 if (auto xorOp = registerReset.getDefiningOp<XorOp>()) {
812 if (xorOp.isBinaryNot()) {
813 Value rhs = xorOp.getOperand(0);
814 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
815 asyncResetArguments.insert(blockArg.getArgNumber());
816 auto trueConst = hw::ConstantOp::create(
817 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
818 blockArg.replaceAllUsesWith(trueConst.getResult());
819 }
820 }
821 }
822 }
823 auto constantOp =
824 hw::ConstantOp::create(opBuilder, clonedRegValue.getLoc(),
825 clonedRegValue.getType(), constStateValue);
826 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
827 clonedRegOp->erase();
828 }
829 GreedyRewriteConfig config;
830 SmallVector<Operation *> opsToProcess;
831 outputRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
832 // Replace references to arguments in the output block with
833 // arguments at the top level.
834 for (auto arg : outputRegion.front().getArguments()) {
835 int argIndex = arg.getArgNumber();
836 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
837 arg.replaceAllUsesWith(topLevelArg);
838 }
839 outputRegion.front().eraseArguments(
840 [](BlockArgument arg) { return true; });
841 FrozenRewritePatternSet patterns(opBuilder.getContext());
842 config.setScope(&outputRegion);
843
844 bool changed = false;
845 if (failed(applyOpPatternsGreedily(opsToProcess, patterns, config,
846 &changed)))
847 return failure();
848 opBuilder.setInsertionPoint(stateOp);
849 // hw.module uses graph regions that allow cycles (e.g., registers feeding
850 // back into themselves). By this point we've replaced all registers with
851 // constants, but cycles in purely combinational logic (e.g., cyclic
852 // muxes) may still exist. Such cycles cannot be converted to FSM.
853 bool sorted = sortTopologically(&outputRegion.front());
854 if (!sorted) {
855 moduleOp.emitError()
856 << "cannot convert module with combinational cycles to FSM";
857 return failure();
858 }
859 Region &transitionRegion = stateOp.getTransitions();
860 llvm::SetVector<size_t> visitableStates;
861 if (failed(getReachableStates(visitableStates, moduleOp,
862 currentStateIndex, registers, opBuilder,
863 currentStateIndex == initialStateIndex)))
864 return failure();
865 for (size_t j : visitableStates) {
866 StateOp toState;
867 if (!stateToStateOp.contains(j)) {
868 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
869 toState =
870 StateOp::create(opBuilder, loc, "state_" + std::to_string(j));
871 stateToStateOp.insert({j, toState});
872 stateOpToState.insert({toState, j});
873 } else {
874 toState = stateToStateOp[j];
875 }
876 opBuilder.setInsertionPointToStart(&transitionRegion.front());
877 auto transitionOp =
878 TransitionOp::create(opBuilder, loc, "state_" + std::to_string(j));
879 Region &guardRegion = transitionOp.getGuard();
880 opBuilder.createBlock(&guardRegion);
881
882 Block &guardBlock = guardRegion.front();
883
884 opBuilder.setInsertionPointToStart(&guardBlock);
885 IRMapping mapping;
886 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), guardRegion,
887 guardBlock.getIterator(), mapping);
888 guardBlock.erase();
889 Block &newGuardBlock = guardRegion.front();
890 Operation *terminator = newGuardBlock.getTerminator();
891 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
892 assert(hwOutputOp && "Expected terminator to be hw.output op");
893
894 llvm::MapVector<Value, int> toStateMap = intToRegMap(registers, j);
895 SmallVector<Value> equalityChecks;
896 for (auto &[originalRegValue, variableOp] : variableMap) {
897 opBuilder.setInsertionPointToStart(&newGuardBlock);
898 Value clonedRegValue = mapping.lookup(originalRegValue);
899 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
900 auto reg = cast<seq::CompRegOp>(clonedRegOp);
901 const auto res = variableOp.getResult();
902 clonedRegValue.replaceAllUsesWith(res);
903 reg.erase();
904 }
905 for (auto const &[originalRegValue, constStateValue] : toStateMap) {
906
907 Value clonedRegValue = mapping.lookup(originalRegValue);
908 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
909 opBuilder.setInsertionPoint(clonedRegOp);
910 auto r = cast<seq::CompRegOp>(clonedRegOp);
911
912 Value registerInput = r.getInput();
913 TypedValue<IntegerType> registerReset = r.getReset();
914 if (registerReset) {
915 if (BlockArgument blockArg =
916 dyn_cast<BlockArgument>(registerReset)) {
917 auto falseConst = hw::ConstantOp::create(
918 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
919 blockArg.replaceAllUsesWith(falseConst.getResult());
920 }
921 if (auto xorOp = registerReset.getDefiningOp<XorOp>()) {
922 if (xorOp.isBinaryNot()) {
923 Value rhs = xorOp.getOperand(0);
924 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
925 auto trueConst = hw::ConstantOp::create(
926 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
927 blockArg.replaceAllUsesWith(trueConst.getResult());
928 }
929 }
930 }
931 }
932 Type constantType = registerInput.getType();
933 IntegerAttr constantAttr =
934 opBuilder.getIntegerAttr(constantType, constStateValue);
935 auto otherStateConstant = hw::ConstantOp::create(
936 opBuilder, hwOutputOp.getLoc(), constantAttr);
937
938 auto doesEqual =
939 ICmpOp::create(opBuilder, hwOutputOp.getLoc(), ICmpPredicate::eq,
940 registerInput, otherStateConstant.getResult());
941 equalityChecks.push_back(doesEqual.getResult());
942 }
943 opBuilder.setInsertionPoint(hwOutputOp);
944 auto allEqualCheck = AndOp::create(opBuilder, hwOutputOp.getLoc(),
945 equalityChecks, false);
946 fsm::ReturnOp::create(opBuilder, hwOutputOp.getLoc(),
947 allEqualCheck.getResult());
948 hwOutputOp.erase();
949 for (BlockArgument arg : newGuardBlock.getArguments()) {
950 int argIndex = arg.getArgNumber();
951 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
952 arg.replaceAllUsesWith(topLevelArg);
953 }
954 newGuardBlock.eraseArguments([](BlockArgument arg) { return true; });
955 llvm::MapVector<Value, int> fromStateMap =
956 intToRegMap(registers, currentStateIndex);
957 for (auto const &[originalRegValue, constStateValue] : fromStateMap) {
958 Value clonedRegValue = mapping.lookup(originalRegValue);
959 assert(clonedRegValue &&
960 "Original register value not found in mapping");
961 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
962 assert(clonedRegOp && "Cloned value must have a defining op");
963 opBuilder.setInsertionPoint(clonedRegOp);
964 auto constantOp =
965 hw::ConstantOp::create(opBuilder, clonedRegValue.getLoc(),
966 clonedRegValue.getType(), constStateValue);
967 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
968 clonedRegOp->erase();
969 }
970 Region &actionRegion = transitionOp.getAction();
971 if (!variableRegs.empty()) {
972 Block *actionBlock = opBuilder.createBlock(&actionRegion);
973 opBuilder.setInsertionPointToStart(actionBlock);
974 IRMapping mapping;
975 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), actionRegion,
976 actionBlock->getIterator(), mapping);
977 actionBlock->erase();
978 Block &newActionBlock = actionRegion.front();
979 for (BlockArgument arg : newActionBlock.getArguments()) {
980 int argIndex = arg.getArgNumber();
981 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
982 arg.replaceAllUsesWith(topLevelArg);
983 }
984 newActionBlock.eraseArguments([](BlockArgument arg) { return true; });
985 for (auto &[originalRegValue, variableOp] : variableMap) {
986 Value clonedRegValue = mapping.lookup(originalRegValue);
987 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
988 auto reg = cast<seq::CompRegOp>(clonedRegOp);
989 opBuilder.setInsertionPointToStart(&newActionBlock);
990 UpdateOp::create(opBuilder, reg.getLoc(), variableOp,
991 reg.getInput());
992 const Value res = variableOp.getResult();
993 clonedRegValue.replaceAllUsesWith(res);
994 reg.erase();
995 }
996 Operation *terminator = actionRegion.back().getTerminator();
997 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
998 assert(hwOutputOp && "Expected terminator to be hw.output op");
999 hwOutputOp.erase();
1000
1001 for (auto const &[originalRegValue, constStateValue] : fromStateMap) {
1002 Value clonedRegValue = mapping.lookup(originalRegValue);
1003 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1004 opBuilder.setInsertionPoint(clonedRegOp);
1005 auto constantOp = hw::ConstantOp::create(
1006 opBuilder, clonedRegValue.getLoc(), clonedRegValue.getType(),
1007 constStateValue);
1008 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
1009 clonedRegOp->erase();
1010 }
1011
1012 GreedyRewriteConfig config;
1013 SmallVector<Operation *> opsToProcess;
1014 actionRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
1015 config.setScope(&actionRegion);
1016
1017 bool changed = false;
1018 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns,
1019 config, &changed)))
1020 return failure();
1021
1022 // hw.module uses graph regions that allow cycles. By this point
1023 // we've replaced all registers with constants, but cycles in purely
1024 // combinational logic may still exist.
1025 bool actionSorted = sortTopologically(&actionRegion.front());
1026 if (!actionSorted) {
1027 moduleOp.emitError()
1028 << "cannot convert module with combinational cycles to FSM";
1029 return failure();
1030 }
1031 }
1032
1033 // hw.module uses graph regions that allow cycles. By this point
1034 // we've replaced all registers with constants, but cycles in purely
1035 // combinational logic may still exist.
1036 bool guardSorted = sortTopologically(&newGuardBlock);
1037 if (!guardSorted) {
1038 moduleOp.emitError()
1039 << "cannot convert module with combinational cycles to FSM";
1040 return failure();
1041 }
1042 SmallVector<Operation *> outputOps;
1043 stateOp.getOutput().walk(
1044 [&](Operation *op) { outputOps.push_back(op); });
1045
1046 bool changed = false;
1047 GreedyRewriteConfig config;
1048 config.setScope(&stateOp.getOutput());
1049 LogicalResult converged = applyOpPatternsGreedily(
1050 outputOps, frozenPatterns, config, &changed);
1051 assert(succeeded(converged) && "canonicalization failed to converge");
1052 SmallVector<Operation *> transitionOps;
1053 stateOp.getTransitions().walk(
1054 [&](Operation *op) { transitionOps.push_back(op); });
1055
1056 GreedyRewriteConfig config2;
1057 config2.setScope(&stateOp.getTransitions());
1058 if (failed(applyOpPatternsGreedily(transitionOps, frozenPatterns,
1059 config2, &changed))) {
1060 return failure();
1061 }
1062
1063 // Propagate guard conditions into action blocks to eliminate
1064 // redundant muxes. The guard already checks the transition
1065 // condition, so the action block can assume it holds.
1066 for (TransitionOp transition :
1067 stateOp.getTransitions().getOps<TransitionOp>())
1068 simplifyActionWithGuard(transition, opBuilder);
1069
1070 // Re-canonicalize after guard propagation to clean up dead ops.
1071 {
1072 SmallVector<Operation *> postOps;
1073 stateOp.getTransitions().walk(
1074 [&](Operation *op) { postOps.push_back(op); });
1075 GreedyRewriteConfig postConfig;
1076 postConfig.setScope(&stateOp.getTransitions());
1077 if (failed(applyOpPatternsGreedily(postOps, frozenPatterns,
1078 postConfig, &changed)))
1079 return failure();
1080 }
1081
1082 for (TransitionOp transition :
1083 stateOp.getTransitions().getOps<TransitionOp>()) {
1084 StateOp nextState = transition.getNextStateOp();
1085 int nextStateIndex = stateOpToState.lookup(nextState);
1086 auto guardConst = transition.getGuardReturn()
1087 .getOperand()
1088 .getDefiningOp<hw::ConstantOp>();
1089 bool nextStateIsReachable =
1090 !guardConst || (guardConst.getValueAttr().getInt() != 0);
1091 // If we find a valid next state and haven't seen it before, add it to
1092 // the worklist and the set of reachable states.
1093 if (nextStateIsReachable &&
1094 !reachableStates.contains(nextStateIndex)) {
1095 worklist.push_back(nextStateIndex);
1096 reachableStates.insert(nextStateIndex);
1097 }
1098 }
1099 }
1100 }
1101
1102 // Clean up unreachable states. States without an output region are
1103 // placeholder states that were created during reachability analysis but
1104 // never populated (i.e., they are unreachable from the initial state).
1105 SmallVector<StateOp> statesToErase;
1106
1107 // Collect unreachable states (those without an output op).
1108 for (StateOp stateOp : machine.getOps<StateOp>()) {
1109 if (!stateOp.getOutputOp()) {
1110 statesToErase.push_back(stateOp);
1111 }
1112 }
1113
1114 // Erase states in a separate loop to avoid iterator invalidation. We first
1115 // collect all states to erase, then iterate over that list. This is
1116 // necessary because erasing a state while iterating over machine.getOps()
1117 // would invalidate the iterator.
1118 for (StateOp stateOp : statesToErase) {
1119 for (TransitionOp transition : machine.getOps<TransitionOp>()) {
1120 if (transition.getNextStateOp().getSymName() == stateOp.getSymName()) {
1121 transition.erase();
1122 }
1123 }
1124 stateOp.erase();
1125 }
1126
1127 llvm::DenseSet<BlockArgument> asyncResetBlockArguments;
1128 for (auto arg : machine.getBody().front().getArguments()) {
1129 if (asyncResetArguments.contains(arg.getArgNumber())) {
1130 asyncResetBlockArguments.insert(arg);
1131 }
1132 }
1133
1134 // Emit a warning if reset signals were detected and removed.
1135 // The FSM dialect does not support reset signals, so the reset behavior
1136 // is only captured in the initial state. The original reset triggering
1137 // mechanism is not preserved.
1138 if (!asyncResetBlockArguments.empty()) {
1139 moduleOp.emitWarning()
1140 << "reset signals detected and removed from FSM; "
1141 "reset behavior is captured only in the initial state";
1142 }
1143
1144 Block &front = machine.getBody().front();
1145 front.eraseArguments([&](BlockArgument arg) {
1146 return asyncResetBlockArguments.contains(arg);
1147 });
1148 machine.getBody().front().eraseArguments([&](BlockArgument arg) {
1149 return arg.getType() == seq::ClockType::get(arg.getContext());
1150 });
1151 FunctionType oldFunctionType = machine.getFunctionType();
1152 SmallVector<Type> inputsWithoutClock;
1153 for (unsigned int i = 0; i < oldFunctionType.getNumInputs(); i++) {
1154 Type input = oldFunctionType.getInput(i);
1155 if (input != seq::ClockType::get(input.getContext()) &&
1156 !asyncResetArguments.contains(i))
1157 inputsWithoutClock.push_back(input);
1158 }
1159
1160 FunctionType newFunctionType = FunctionType::get(
1161 opBuilder.getContext(), inputsWithoutClock, resultTypes);
1162
1163 machine.setFunctionType(newFunctionType);
1164 moduleOp.erase();
1165 return success();
1166 }
1167
1168private:
1169 /// Helper function to determine if a register is a state register.
1170 bool isStateRegister(seq::CompRegOp reg) const {
1171 auto regName = reg.getName();
1172 if (!regName)
1173 return false;
1174
1175 // If user specified state registers, check if this register's name matches
1176 // any of them.
1177 if (!stateRegNames.empty()) {
1178 return llvm::is_contained(stateRegNames, regName->str());
1179 }
1180
1181 // Default behavior: infer state registers by checking if the name contains
1182 // "state".
1183 return regName->contains("state");
1184 }
1185
1186 HWModuleOp moduleOp;
1187 OpBuilder &opBuilder;
1188 ArrayRef<std::string> stateRegNames;
1189};
1190
1191} // namespace
1192
1193namespace {
1194struct CoreToFSMPass : public circt::impl::ConvertCoreToFSMBase<CoreToFSMPass> {
1195 using ConvertCoreToFSMBase<CoreToFSMPass>::ConvertCoreToFSMBase;
1196
1197 void runOnOperation() override {
1198 auto module = getOperation();
1199 OpBuilder builder(module);
1200
1201 SmallVector<HWModuleOp> modules;
1202 for (auto hwModule : module.getOps<HWModuleOp>()) {
1203 modules.push_back(hwModule);
1204 }
1205
1206 // Check for hw.instance operations - instance conversion is not supported.
1207 for (auto hwModule : modules) {
1208 for (auto instance : hwModule.getOps<hw::InstanceOp>()) {
1209 instance.emitError() << "instance conversion is not yet supported";
1210 signalPassFailure();
1211 return;
1212 }
1213 }
1214
1215 for (auto hwModule : modules) {
1216 builder.setInsertionPoint(hwModule);
1217 HWModuleOpConverter converter(builder, hwModule, stateRegs);
1218 if (failed(converter.run())) {
1219 signalPassFailure();
1220 return;
1221 }
1222 }
1223 }
1224};
1225} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
create(*operands)
Definition fsm.py:160
create(name)
Definition fsm.py:144
create(to_state)
Definition fsm.py:110
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:445
Definition fsm.py:1
Definition hw.py:1
Definition seq.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21