CIRCT 22.0.0git
Loading...
Searching...
No Matches
CoreToFSM.cpp
Go to the documentation of this file.
1//===CoreToFSM.cpp - Convert Core Dialects (HW + Seq + Comb) to FSM Dialect===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17#include "circt/Support/LLVM.h"
18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/IR/Block.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/Diagnostics.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "mlir/Transforms/Passes.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/LogicalResult.h"
36#include <string>
37
38namespace circt {
39#define GEN_PASS_DEF_CONVERTCORETOFSM
40#include "circt/Conversion/Passes.h.inc"
41} // namespace circt
42
43using namespace mlir;
44using namespace circt;
45using namespace hw;
46using namespace comb;
47using namespace fsm;
48
49namespace {
50
51// Forward declaration for our helper function
52static void generateConcatenatedValues(
53 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
54 const llvm::SmallVector<unsigned> &shifts,
55 llvm::SetVector<size_t> &finalPossibleValues);
56
57/// Internal helper with visited set to detect cycles.
58static void addPossibleValuesImpl(llvm::SetVector<size_t> &possibleValues,
59 Value v, llvm::DenseSet<Value> &visited) {
60 // Detect cycles - if we've seen this value before, skip it.
61 if (!visited.insert(v).second)
62 return;
63
64 if (auto c = dyn_cast_or_null<hw::ConstantOp>(v.getDefiningOp())) {
65 possibleValues.insert(c.getValueAttr().getValue().getZExtValue());
66 return;
67 }
68 if (auto m = dyn_cast_or_null<MuxOp>(v.getDefiningOp())) {
69 addPossibleValuesImpl(possibleValues, m.getTrueValue(), visited);
70 addPossibleValuesImpl(possibleValues, m.getFalseValue(), visited);
71 return;
72 }
73
74 if (auto concatOp = dyn_cast_or_null<ConcatOp>(v.getDefiningOp())) {
75 llvm::SmallVector<llvm::SetVector<size_t>> allOperandValues;
76 llvm::SmallVector<unsigned> operandWidths;
77
78 for (Value operand : concatOp.getOperands()) {
79 llvm::SetVector<size_t> operandPossibleValues;
80 addPossibleValuesImpl(operandPossibleValues, operand, visited);
81
82 // It's crucial to handle the case where a sub-computation is too complex.
83 // If we can't determine specific values for an operand, we must
84 // pessimistically assume it can be any value its bitwidth allows.
85 auto opType = dyn_cast<IntegerType>(operand.getType());
86 // comb.concat only accepts signless integer operands by definition in
87 // CIRCT's type system, so this assertion should always hold for valid IR.
88 assert(opType && "comb.concat operand must be an integer type");
89 unsigned width = opType.getWidth();
90 if (operandPossibleValues.empty()) {
91 uint64_t numStates = 1ULL << width;
92 // Add a threshold to prevent combinatorial explosion on large unknown
93 // inputs.
94 if (numStates > 256) { // Heuristic threshold
95 // If the search space is too large, we abandon the analysis for this
96 // path. The outer function will fall back to its own full-range
97 // default.
98 v.getDefiningOp()->emitWarning()
99 << "Search space too large (>" << 256
100 << " states) for operand with bitwidth " << width
101 << "; abandoning analysis for this path";
102 return;
103 }
104 for (uint64_t i = 0; i < numStates; ++i)
105 operandPossibleValues.insert(i);
106 }
107
108 allOperandValues.push_back(operandPossibleValues);
109 operandWidths.push_back(width);
110 }
111
112 // The shift for operand `i` is the sum of the widths of operands `i+1` to
113 // `n-1`.
114 llvm::SmallVector<unsigned> shifts(concatOp.getNumOperands(), 0);
115 for (int i = concatOp.getNumOperands() - 2; i >= 0; --i) {
116 shifts[i] = shifts[i + 1] + operandWidths[i + 1];
117 }
118
119 generateConcatenatedValues(allOperandValues, shifts, possibleValues);
120 return;
121 }
122
123 // --- Fallback Case ---
124 // If the operation is not recognized, assume all possible values for its
125 // bitwidth.
126
127 auto addrType = dyn_cast<IntegerType>(v.getType());
128 if (!addrType)
129 return; // Not an integer type we can analyze
130
131 unsigned bitWidth = addrType.getWidth();
132 // Again, use a threshold to avoid trying to enumerate 2^64 values.
133 if (bitWidth > 16) {
134 if (v.getDefiningOp())
135 v.getDefiningOp()->emitWarning()
136 << "Bitwidth " << bitWidth
137 << " too large (>16); abandoning analysis for this path";
138 return;
139 }
140
141 uint64_t numRegStates = 1ULL << bitWidth;
142 for (size_t i = 0; i < numRegStates; i++) {
143 possibleValues.insert(i);
144 }
145}
146
147static void addPossibleValues(llvm::SetVector<size_t> &possibleValues,
148 Value v) {
149 llvm::DenseSet<Value> visited;
150 addPossibleValuesImpl(possibleValues, v, visited);
151}
152
153/// Checks if a value is a constant or a tree of muxes with constant leaves.
154/// Uses an iterative approach with a visited set to handle cycles.
155static bool isConstantOrConstantTree(Value value) {
156 SmallVector<Value> worklist;
157 llvm::DenseSet<Value> visited;
158
159 worklist.push_back(value);
160 while (!worklist.empty()) {
161 Value current = worklist.pop_back_val();
162
163 // Skip if already visited (handles cycles).
164 if (!visited.insert(current).second)
165 continue;
166
167 Operation *definingOp = current.getDefiningOp();
168 if (!definingOp)
169 return false;
170
171 if (isa<hw::ConstantOp>(definingOp))
172 continue;
173
174 if (auto muxOp = dyn_cast<MuxOp>(definingOp)) {
175 worklist.push_back(muxOp.getTrueValue());
176 worklist.push_back(muxOp.getFalseValue());
177 continue;
178 }
179
180 // Not a constant or mux - not constant-like.
181 return false;
182 }
183 return true;
184}
185
186/// Pushes an ICmp equality comparison through a mux operation.
187/// This transforms `icmp eq (mux cond, x, y), b` into
188/// `mux cond, (icmp eq x, b), (icmp eq y, b)`.
189/// This simplification helps expose constant comparisons that can be folded
190/// during FSM extraction, making transition guards easier to analyze.
191LogicalResult pushIcmp(ICmpOp op, PatternRewriter &rewriter) {
192 APInt lhs, rhs;
193 if (op.getPredicate() == ICmpPredicate::eq &&
194 op.getLhs().getDefiningOp<MuxOp>() &&
195 (isConstantOrConstantTree(op.getLhs()) ||
196 op.getRhs().getDefiningOp<hw::ConstantOp>())) {
197 rewriter.setInsertionPointAfter(op);
198 auto mux = op.getLhs().getDefiningOp<MuxOp>();
199 Value x = mux.getTrueValue();
200 Value y = mux.getFalseValue();
201 Value b = op.getRhs();
202 Location loc = op.getLoc();
203 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
204 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
205 rewriter.replaceOpWithNewOp<MuxOp>(op, mux.getCond(), eq1.getResult(),
206 eq2.getResult());
207 return llvm::success();
208 }
209 if (op.getPredicate() == ICmpPredicate::eq &&
210 op.getRhs().getDefiningOp<MuxOp>() &&
211 (isConstantOrConstantTree(op.getRhs()) ||
212 op.getLhs().getDefiningOp<hw::ConstantOp>())) {
213 rewriter.setInsertionPointAfter(op);
214 auto mux = op.getRhs().getDefiningOp<MuxOp>();
215 Value x = mux.getTrueValue();
216 Value y = mux.getFalseValue();
217 Value b = op.getLhs();
218 Location loc = op.getLoc();
219 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
220 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
221 rewriter.replaceOpWithNewOp<MuxOp>(op, mux.getCond(), eq1.getResult(),
222 eq2.getResult());
223 return llvm::success();
224 }
225 return llvm::failure();
226}
227
228/// Iteratively builds all possible concatenated integer values from the
229/// Cartesian product of value sets.
230static void generateConcatenatedValues(
231 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
232 const llvm::SmallVector<unsigned> &shifts,
233 llvm::SetVector<size_t> &finalPossibleValues) {
234
235 if (allOperandValues.empty()) {
236 finalPossibleValues.insert(0);
237 return;
238 }
239
240 // Start with the values of the first operand, shifted appropriately.
241 llvm::SetVector<size_t> currentResults;
242 for (size_t val : allOperandValues[0])
243 currentResults.insert(val << shifts[0]);
244
245 // For each subsequent operand, combine with all existing partial results.
246 for (size_t operandIdx = 1; operandIdx < allOperandValues.size();
247 ++operandIdx) {
248 llvm::SetVector<size_t> nextResults;
249 unsigned shift = shifts[operandIdx];
250
251 for (size_t partialValue : currentResults) {
252 for (size_t val : allOperandValues[operandIdx]) {
253 nextResults.insert(partialValue | (val << shift));
254 }
255 }
256 currentResults = std::move(nextResults);
257 }
258
259 finalPossibleValues = std::move(currentResults);
260}
261
262static llvm::MapVector<Value, int> intToRegMap(SmallVector<seq::CompRegOp> v,
263 int i) {
264 llvm::MapVector<Value, int> m;
265 for (size_t ci = 0; ci < v.size(); ci++) {
266 seq::CompRegOp reg = v[ci];
267 int bits = reg.getType().getIntOrFloatBitWidth();
268 int v = i & ((1 << bits) - 1);
269 m[reg] = v;
270 i = i >> bits;
271 }
272 return m;
273}
274
275static int regMapToInt(SmallVector<seq::CompRegOp> v,
276 llvm::DenseMap<Value, int> m) {
277 int i = 0;
278 int width = 0;
279 for (size_t ci = 0; ci < v.size(); ci++) {
280 seq::CompRegOp reg = v[ci];
281 i += m[reg] * 1ULL << width;
282 width += (reg.getType().getIntOrFloatBitWidth());
283 }
284 return i;
285}
286
287/// Computes the Cartesian product of a list of sets.
288static std::set<llvm::SmallVector<size_t>> calculateCartesianProduct(
289 const llvm::SmallVector<llvm::SetVector<size_t>> &valueSets) {
290 std::set<llvm::SmallVector<size_t>> product;
291 if (valueSets.empty()) {
292 // The Cartesian product of zero sets is a set containing one element:
293 // the empty tuple (represented here by an empty vector).
294 product.insert({});
295 return product;
296 }
297
298 // Initialize the product with the elements of the first set, each in its
299 // own vector.
300 for (size_t value : valueSets.front()) {
301 product.insert({value});
302 }
303
304 // Iteratively build the product. For each subsequent set, create a new
305 // temporary product by appending each of its elements to every combination
306 // already generated.
307 for (size_t i = 1; i < valueSets.size(); ++i) {
308 const auto &currentSet = valueSets[i];
309 if (currentSet.empty()) {
310 // The Cartesian product with an empty set results in an empty set.
311 return {};
312 }
313
314 std::set<llvm::SmallVector<size_t>> newProduct;
315 for (const auto &existingVector : product) {
316 for (size_t newValue : currentSet) {
317 llvm::SmallVector<size_t> newVector = existingVector;
318 newVector.push_back(newValue);
319 newProduct.insert(std::move(newVector));
320 }
321 }
322 product = std::move(newProduct);
323 }
324
325 return product;
326}
327
328static FrozenRewritePatternSet loadPatterns(MLIRContext &context) {
329
330 RewritePatternSet patterns(&context);
331 for (auto *dialect : context.getLoadedDialects())
332 dialect->getCanonicalizationPatterns(patterns);
333 ICmpOp::getCanonicalizationPatterns(patterns, &context);
334 AndOp::getCanonicalizationPatterns(patterns, &context);
335 XorOp::getCanonicalizationPatterns(patterns, &context);
336 MuxOp::getCanonicalizationPatterns(patterns, &context);
337 ConcatOp::getCanonicalizationPatterns(patterns, &context);
338 ExtractOp::getCanonicalizationPatterns(patterns, &context);
339 AddOp::getCanonicalizationPatterns(patterns, &context);
340 OrOp::getCanonicalizationPatterns(patterns, &context);
341 MulOp::getCanonicalizationPatterns(patterns, &context);
342 hw::ConstantOp::getCanonicalizationPatterns(patterns, &context);
343 TransitionOp::getCanonicalizationPatterns(patterns, &context);
344 StateOp::getCanonicalizationPatterns(patterns, &context);
345 MachineOp::getCanonicalizationPatterns(patterns, &context);
346 patterns.add(pushIcmp);
347 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
348 return frozenPatterns;
349}
350
351static LogicalResult
352getReachableStates(llvm::SetVector<size_t> &visitableStates,
353 HWModuleOp moduleOp, size_t currentStateIndex,
354 SmallVector<seq::CompRegOp> registers, OpBuilder opBuilder,
355 bool isInitialState) {
356
357 IRMapping mapping;
358 auto clonedBody =
359 llvm::dyn_cast<HWModuleOp>(opBuilder.clone(*moduleOp, mapping));
360
361 llvm::MapVector<Value, int> stateMap =
362 intToRegMap(registers, currentStateIndex);
363 Operation *terminator = clonedBody.getBody().front().getTerminator();
364 auto output = dyn_cast<hw::OutputOp>(terminator);
365 SmallVector<Value> values;
366
367 for (auto [originalRegValue, constStateValue] : stateMap) {
368
369 Value clonedRegValue = mapping.lookup(originalRegValue);
370 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
371 auto reg = cast<seq::CompRegOp>(clonedRegOp);
372 Type constantType = reg.getType();
373 IntegerAttr constantAttr =
374 opBuilder.getIntegerAttr(constantType, constStateValue);
375 opBuilder.setInsertionPoint(clonedRegOp);
376 auto otherStateConstant =
377 hw::ConstantOp::create(opBuilder, reg.getLoc(), constantAttr);
378 // If the register input is self-referential (input == output), use the
379 // constant we're replacing it with. Otherwise, the value would become
380 // dangling after we erase the register.
381 Value regInput = reg.getInput();
382 if (regInput == clonedRegValue)
383 values.push_back(otherStateConstant.getResult());
384 else
385 values.push_back(regInput);
386 clonedRegValue.replaceAllUsesWith(otherStateConstant.getResult());
387 reg.erase();
388 }
389 opBuilder.setInsertionPointToEnd(clonedBody.front().getBlock());
390 auto newOutput = hw::OutputOp::create(opBuilder, output.getLoc(), values);
391 output.erase();
392 FrozenRewritePatternSet frozenPatterns = loadPatterns(*moduleOp.getContext());
393
394 SmallVector<Operation *> opsToProcess;
395 clonedBody.walk([&](Operation *op) { opsToProcess.push_back(op); });
396
397 bool changed = false;
398 GreedyRewriteConfig config;
399 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns, config,
400 &changed)))
401 return failure();
402
403 llvm::SmallVector<llvm::SetVector<size_t>> pv;
404 for (size_t j = 0; j < newOutput.getNumOperands(); j++) {
405 llvm::SetVector<size_t> possibleValues;
406
407 Value v = newOutput.getOperand(j);
408 addPossibleValues(possibleValues, v);
409 pv.push_back(possibleValues);
410 }
411 std::set<llvm::SmallVector<size_t>> flipped = calculateCartesianProduct(pv);
412 for (llvm::SmallVector<size_t> v : flipped) {
413 llvm::DenseMap<Value, int> m;
414 for (size_t k = 0; k < v.size(); k++) {
415 seq::CompRegOp r = registers[k];
416 m[r] = v[k];
417 }
418
419 int i = regMapToInt(registers, m);
420 visitableStates.insert(i);
421 }
422
423 clonedBody.erase();
424 return success();
425}
426
427// A converter class to handle the logic of converting a single hw.module.
428class HWModuleOpConverter {
429public:
430 HWModuleOpConverter(OpBuilder &builder, HWModuleOp moduleOp,
431 ArrayRef<std::string> stateRegNames)
432 : moduleOp(moduleOp), opBuilder(builder), stateRegNames(stateRegNames) {}
433 LogicalResult run() {
434 SmallVector<seq::CompRegOp> stateRegs;
435 SmallVector<seq::CompRegOp> variableRegs;
436 WalkResult walkResult = moduleOp.walk([&](seq::CompRegOp reg) {
437 // Check that the register type is an integer.
438 if (!isa<IntegerType>(reg.getType())) {
439 reg.emitError("FSM extraction only supports integer-typed registers");
440 return WalkResult::interrupt();
441 }
442 if (isStateRegister(reg)) {
443 stateRegs.push_back(reg);
444 } else {
445 variableRegs.push_back(reg);
446 }
447 return WalkResult::advance();
448 });
449 if (walkResult.wasInterrupted())
450 return failure();
451 if (stateRegs.empty()) {
452 emitError(moduleOp.getLoc())
453 << "Cannot find state register in this FSM. Use the state-regs "
454 "option to specify which registers are state registers.";
455 return failure();
456 }
457 SmallVector<seq::CompRegOp> registers;
458 for (seq::CompRegOp c : stateRegs) {
459 registers.push_back(c);
460 }
461
462 llvm::DenseMap<size_t, StateOp> stateToStateOp;
463 llvm::DenseMap<StateOp, size_t> stateOpToState;
464 // Collect reset arguments to exclude from the FSM's function type.
465 // All CompReg reset signals are ignored during FSM extraction since the
466 // FSM dialect does not have an explicit reset concept. The reset behavior
467 // is only captured in the initial state value.
468 llvm::DenseSet<size_t> asyncResetArguments;
469 Location loc = moduleOp.getLoc();
470 SmallVector<Type> inputTypes = moduleOp.getInputTypes();
471
472 // Create a new FSM machine with the current state.
473 auto resultTypes = moduleOp.getOutputTypes();
474 FunctionType machineType =
475 FunctionType::get(opBuilder.getContext(), inputTypes, resultTypes);
476 StringRef machineName = moduleOp.getName();
477
478 llvm::DenseMap<Value, int> initialStateMap;
479 for (seq::CompRegOp reg : moduleOp.getOps<seq::CompRegOp>()) {
480 Value resetValue = reg.getResetValue();
481 auto definingConstant = resetValue.getDefiningOp<hw::ConstantOp>();
482 if (!definingConstant) {
483 reg->emitError(
484 "cannot find defining constant for reset value of register");
485 return failure();
486 }
487 int resetValueInt =
488 definingConstant.getValueAttr().getValue().getZExtValue();
489 initialStateMap[reg] = resetValueInt;
490 }
491 int initialStateIndex = regMapToInt(registers, initialStateMap);
492
493 std::string initialStateName = "state_" + std::to_string(initialStateIndex);
494
495 // Preserve argument and result names, which are stored as attributes.
496 SmallVector<NamedAttribute> machineAttrs;
497 if (auto argNames = moduleOp->getAttrOfType<ArrayAttr>("argNames"))
498 machineAttrs.emplace_back(opBuilder.getStringAttr("argNames"), argNames);
499 if (auto resNames = moduleOp->getAttrOfType<ArrayAttr>("resultNames"))
500 machineAttrs.emplace_back(opBuilder.getStringAttr("resNames"), resNames);
501
502 // The builder for fsm.MachineOp will create the body region and block
503 // arguments.
504 opBuilder.setInsertionPoint(moduleOp);
505 auto machine =
506 MachineOp::create(opBuilder, loc, machineName, initialStateName,
507 machineType, machineAttrs);
508
509 OpBuilder::InsertionGuard guard(opBuilder);
510 opBuilder.setInsertionPointToStart(&machine.getBody().front());
511 llvm::MapVector<seq::CompRegOp, VariableOp> variableMap;
512 for (seq::CompRegOp varReg : variableRegs) {
513 TypedValue<Type> initialValue = varReg.getResetValue();
514 auto definingConstant = initialValue.getDefiningOp<hw::ConstantOp>();
515 if (!definingConstant) {
516 varReg->emitError("cannot find defining constant for reset value of "
517 "variable register");
518 return failure();
519 }
520 auto variableOp = VariableOp::create(
521 opBuilder, varReg->getLoc(), varReg.getInput().getType(),
522 definingConstant.getValueAttr(), varReg.getName().value_or("var"));
523 variableMap[varReg] = variableOp;
524 }
525
526 // Load rewrite patterns used for canonicalizing the generated FSM.
527 FrozenRewritePatternSet frozenPatterns =
528 loadPatterns(*moduleOp.getContext());
529
530 SetVector<int> reachableStates;
531 SmallVector<int> worklist;
532
533 worklist.push_back(initialStateIndex);
534 reachableStates.insert(initialStateIndex);
535 // Process states in BFS order. The worklist grows as new reachable states
536 // are discovered, so we use an index-based loop.
537 for (unsigned i = 0; i < worklist.size(); ++i) {
538
539 int currentStateIndex = worklist[i];
540
541 llvm::MapVector<Value, int> stateMap =
542 intToRegMap(registers, currentStateIndex);
543
544 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
545
546 StateOp stateOp;
547
548 if (!stateToStateOp.contains(currentStateIndex)) {
549 stateOp = StateOp::create(opBuilder, loc,
550 "state_" + std::to_string(currentStateIndex));
551 stateToStateOp.insert({currentStateIndex, stateOp});
552 stateOpToState.insert({stateOp, currentStateIndex});
553 } else {
554 stateOp = stateToStateOp.lookup(currentStateIndex);
555 }
556 Region &outputRegion = stateOp.getOutput();
557 Block *outputBlock = &outputRegion.front();
558 opBuilder.setInsertionPointToStart(outputBlock);
559 IRMapping mapping;
560 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), outputRegion,
561 outputBlock->getIterator(), mapping);
562 outputBlock->erase();
563
564 auto *terminator = outputRegion.front().getTerminator();
565 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
566 assert(hwOutputOp && "Expected terminator to be hw.output op");
567
568 // Position the builder to insert the new terminator right before the
569 // old one.
570 OpBuilder::InsertionGuard stateGuard(opBuilder);
571 opBuilder.setInsertionPoint(hwOutputOp);
572
573 // Create the new fsm.OutputOp with the same operands.
574
575 fsm::OutputOp::create(opBuilder, hwOutputOp.getLoc(),
576 hwOutputOp.getOperands());
577
578 // Erase the old terminator.
579 hwOutputOp.erase();
580
581 // Iterate through the state configuration to replace registers
582 // with constants.
583 for (auto &[originalRegValue, variableOp] : variableMap) {
584 Value clonedRegValue = mapping.lookup(originalRegValue);
585 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
586 auto reg = cast<seq::CompRegOp>(clonedRegOp);
587 const auto res = variableOp.getResult();
588 clonedRegValue.replaceAllUsesWith(res);
589 reg.erase();
590 }
591 for (auto const &[originalRegValue, constStateValue] : stateMap) {
592 // Find the cloned register's result value using the mapping.
593 Value clonedRegValue = mapping.lookup(originalRegValue);
594 assert(clonedRegValue &&
595 "Original register value not found in mapping");
596 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
597
598 assert(clonedRegOp && "Cloned value must have a defining op");
599 opBuilder.setInsertionPoint(clonedRegOp);
600 auto r = cast<seq::CompRegOp>(clonedRegOp);
601 TypedValue<IntegerType> registerReset = r.getReset();
602 if (registerReset) {
603 if (BlockArgument blockArg = dyn_cast<BlockArgument>(registerReset)) {
604 asyncResetArguments.insert(blockArg.getArgNumber());
605 auto falseConst = hw::ConstantOp::create(
606 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
607 blockArg.replaceAllUsesWith(falseConst.getResult());
608 }
609 if (auto xorOp = registerReset.getDefiningOp<XorOp>()) {
610 if (xorOp.isBinaryNot()) {
611 Value rhs = xorOp.getOperand(0);
612 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
613 asyncResetArguments.insert(blockArg.getArgNumber());
614 auto trueConst = hw::ConstantOp::create(
615 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
616 blockArg.replaceAllUsesWith(trueConst.getResult());
617 }
618 }
619 }
620 }
621 auto constantOp =
622 hw::ConstantOp::create(opBuilder, clonedRegValue.getLoc(),
623 clonedRegValue.getType(), constStateValue);
624 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
625 clonedRegOp->erase();
626 }
627 GreedyRewriteConfig config;
628 SmallVector<Operation *> opsToProcess;
629 outputRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
630 // Replace references to arguments in the output block with
631 // arguments at the top level.
632 for (auto arg : outputRegion.front().getArguments()) {
633 int argIndex = arg.getArgNumber();
634 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
635 arg.replaceAllUsesWith(topLevelArg);
636 }
637 outputRegion.front().eraseArguments(
638 [](BlockArgument arg) { return true; });
639 FrozenRewritePatternSet patterns(opBuilder.getContext());
640 config.setScope(&outputRegion);
641
642 bool changed = false;
643 if (failed(applyOpPatternsGreedily(opsToProcess, patterns, config,
644 &changed)))
645 return failure();
646 opBuilder.setInsertionPoint(stateOp);
647 // hw.module uses graph regions that allow cycles (e.g., registers feeding
648 // back into themselves). By this point we've replaced all registers with
649 // constants, but cycles in purely combinational logic (e.g., cyclic
650 // muxes) may still exist. Such cycles cannot be converted to FSM.
651 bool sorted = sortTopologically(&outputRegion.front());
652 if (!sorted) {
653 moduleOp.emitError()
654 << "cannot convert module with combinational cycles to FSM";
655 return failure();
656 }
657 Region &transitionRegion = stateOp.getTransitions();
658 llvm::SetVector<size_t> visitableStates;
659 if (failed(getReachableStates(visitableStates, moduleOp,
660 currentStateIndex, registers, opBuilder,
661 currentStateIndex == initialStateIndex)))
662 return failure();
663 for (size_t j : visitableStates) {
664 StateOp toState;
665 if (!stateToStateOp.contains(j)) {
666 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
667 toState =
668 StateOp::create(opBuilder, loc, "state_" + std::to_string(j));
669 stateToStateOp.insert({j, toState});
670 stateOpToState.insert({toState, j});
671 } else {
672 toState = stateToStateOp[j];
673 }
674 opBuilder.setInsertionPointToStart(&transitionRegion.front());
675 auto transitionOp =
676 TransitionOp::create(opBuilder, loc, "state_" + std::to_string(j));
677 Region &guardRegion = transitionOp.getGuard();
678 opBuilder.createBlock(&guardRegion);
679
680 Block &guardBlock = guardRegion.front();
681
682 opBuilder.setInsertionPointToStart(&guardBlock);
683 IRMapping mapping;
684 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), guardRegion,
685 guardBlock.getIterator(), mapping);
686 guardBlock.erase();
687 Block &newGuardBlock = guardRegion.front();
688 Operation *terminator = newGuardBlock.getTerminator();
689 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
690 assert(hwOutputOp && "Expected terminator to be hw.output op");
691
692 llvm::MapVector<Value, int> toStateMap = intToRegMap(registers, j);
693 SmallVector<Value> equalityChecks;
694 for (auto &[originalRegValue, variableOp] : variableMap) {
695 opBuilder.setInsertionPointToStart(&newGuardBlock);
696 Value clonedRegValue = mapping.lookup(originalRegValue);
697 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
698 auto reg = cast<seq::CompRegOp>(clonedRegOp);
699 const auto res = variableOp.getResult();
700 clonedRegValue.replaceAllUsesWith(res);
701 reg.erase();
702 }
703 for (auto const &[originalRegValue, constStateValue] : toStateMap) {
704
705 Value clonedRegValue = mapping.lookup(originalRegValue);
706 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
707 opBuilder.setInsertionPoint(clonedRegOp);
708 auto r = cast<seq::CompRegOp>(clonedRegOp);
709
710 Value registerInput = r.getInput();
711 TypedValue<IntegerType> registerReset = r.getReset();
712 if (registerReset) {
713 if (BlockArgument blockArg =
714 dyn_cast<BlockArgument>(registerReset)) {
715 auto falseConst = hw::ConstantOp::create(
716 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
717 blockArg.replaceAllUsesWith(falseConst.getResult());
718 }
719 if (auto xorOp = registerReset.getDefiningOp<XorOp>()) {
720 if (xorOp.isBinaryNot()) {
721 Value rhs = xorOp.getOperand(0);
722 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
723 auto trueConst = hw::ConstantOp::create(
724 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
725 blockArg.replaceAllUsesWith(trueConst.getResult());
726 }
727 }
728 }
729 }
730 Type constantType = registerInput.getType();
731 IntegerAttr constantAttr =
732 opBuilder.getIntegerAttr(constantType, constStateValue);
733 auto otherStateConstant = hw::ConstantOp::create(
734 opBuilder, hwOutputOp.getLoc(), constantAttr);
735
736 auto doesEqual =
737 ICmpOp::create(opBuilder, hwOutputOp.getLoc(), ICmpPredicate::eq,
738 registerInput, otherStateConstant.getResult());
739 equalityChecks.push_back(doesEqual.getResult());
740 }
741 opBuilder.setInsertionPoint(hwOutputOp);
742 auto allEqualCheck = AndOp::create(opBuilder, hwOutputOp.getLoc(),
743 equalityChecks, false);
744 fsm::ReturnOp::create(opBuilder, hwOutputOp.getLoc(),
745 allEqualCheck.getResult());
746 hwOutputOp.erase();
747 for (BlockArgument arg : newGuardBlock.getArguments()) {
748 int argIndex = arg.getArgNumber();
749 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
750 arg.replaceAllUsesWith(topLevelArg);
751 }
752 newGuardBlock.eraseArguments([](BlockArgument arg) { return true; });
753 llvm::MapVector<Value, int> fromStateMap =
754 intToRegMap(registers, currentStateIndex);
755 for (auto const &[originalRegValue, constStateValue] : fromStateMap) {
756 Value clonedRegValue = mapping.lookup(originalRegValue);
757 assert(clonedRegValue &&
758 "Original register value not found in mapping");
759 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
760 assert(clonedRegOp && "Cloned value must have a defining op");
761 opBuilder.setInsertionPoint(clonedRegOp);
762 auto constantOp =
763 hw::ConstantOp::create(opBuilder, clonedRegValue.getLoc(),
764 clonedRegValue.getType(), constStateValue);
765 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
766 clonedRegOp->erase();
767 }
768 Region &actionRegion = transitionOp.getAction();
769 if (!variableRegs.empty()) {
770 Block *actionBlock = opBuilder.createBlock(&actionRegion);
771 opBuilder.setInsertionPointToStart(actionBlock);
772 IRMapping mapping;
773 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), actionRegion,
774 actionBlock->getIterator(), mapping);
775 actionBlock->erase();
776 Block &newActionBlock = actionRegion.front();
777 for (BlockArgument arg : newActionBlock.getArguments()) {
778 int argIndex = arg.getArgNumber();
779 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
780 arg.replaceAllUsesWith(topLevelArg);
781 }
782 newActionBlock.eraseArguments([](BlockArgument arg) { return true; });
783 for (auto &[originalRegValue, variableOp] : variableMap) {
784 Value clonedRegValue = mapping.lookup(originalRegValue);
785 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
786 auto reg = cast<seq::CompRegOp>(clonedRegOp);
787 opBuilder.setInsertionPointToStart(&newActionBlock);
788 UpdateOp::create(opBuilder, reg.getLoc(), variableOp,
789 reg.getInput());
790 const Value res = variableOp.getResult();
791 clonedRegValue.replaceAllUsesWith(res);
792 reg.erase();
793 }
794 Operation *terminator = actionRegion.back().getTerminator();
795 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
796 assert(hwOutputOp && "Expected terminator to be hw.output op");
797 hwOutputOp.erase();
798
799 for (auto const &[originalRegValue, constStateValue] : fromStateMap) {
800 Value clonedRegValue = mapping.lookup(originalRegValue);
801 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
802 opBuilder.setInsertionPoint(clonedRegOp);
803 auto constantOp = hw::ConstantOp::create(
804 opBuilder, clonedRegValue.getLoc(), clonedRegValue.getType(),
805 constStateValue);
806 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
807 clonedRegOp->erase();
808 }
809
810 FrozenRewritePatternSet patterns(opBuilder.getContext());
811 GreedyRewriteConfig config;
812 SmallVector<Operation *> opsToProcess;
813 actionRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
814 config.setScope(&actionRegion);
815
816 bool changed = false;
817 if (failed(applyOpPatternsGreedily(opsToProcess, patterns, config,
818 &changed)))
819 return failure();
820
821 // hw.module uses graph regions that allow cycles. By this point
822 // we've replaced all registers with constants, but cycles in purely
823 // combinational logic may still exist.
824 bool actionSorted = sortTopologically(&actionRegion.front());
825 if (!actionSorted) {
826 moduleOp.emitError()
827 << "cannot convert module with combinational cycles to FSM";
828 return failure();
829 }
830 }
831
832 // hw.module uses graph regions that allow cycles. By this point
833 // we've replaced all registers with constants, but cycles in purely
834 // combinational logic may still exist.
835 bool guardSorted = sortTopologically(&newGuardBlock);
836 if (!guardSorted) {
837 moduleOp.emitError()
838 << "cannot convert module with combinational cycles to FSM";
839 return failure();
840 }
841 SmallVector<Operation *> outputOps;
842 stateOp.getOutput().walk(
843 [&](Operation *op) { outputOps.push_back(op); });
844
845 bool changed = false;
846 GreedyRewriteConfig config;
847 config.setScope(&stateOp.getOutput());
848 LogicalResult converged = applyOpPatternsGreedily(
849 outputOps, frozenPatterns, config, &changed);
850 assert(succeeded(converged) && "canonicalization failed to converge");
851 SmallVector<Operation *> transitionOps;
852 stateOp.getTransitions().walk(
853 [&](Operation *op) { transitionOps.push_back(op); });
854
855 GreedyRewriteConfig config2;
856 config2.setScope(&stateOp.getTransitions());
857 if (failed(applyOpPatternsGreedily(transitionOps, frozenPatterns,
858 config2, &changed))) {
859 return failure();
860 }
861
862 for (TransitionOp transition :
863 stateOp.getTransitions().getOps<TransitionOp>()) {
864 StateOp nextState = transition.getNextStateOp();
865 int nextStateIndex = stateOpToState.lookup(nextState);
866 auto guardConst = transition.getGuardReturn()
867 .getOperand()
868 .getDefiningOp<hw::ConstantOp>();
869 bool nextStateIsReachable =
870 !guardConst || (guardConst.getValueAttr().getInt() != 0);
871 // If we find a valid next state and haven't seen it before, add it to
872 // the worklist and the set of reachable states.
873 if (nextStateIsReachable &&
874 !reachableStates.contains(nextStateIndex)) {
875 worklist.push_back(nextStateIndex);
876 reachableStates.insert(nextStateIndex);
877 }
878 }
879 }
880 }
881
882 // Clean up unreachable states. States without an output region are
883 // placeholder states that were created during reachability analysis but
884 // never populated (i.e., they are unreachable from the initial state).
885 SmallVector<StateOp> statesToErase;
886
887 // Collect unreachable states (those without an output op).
888 for (StateOp stateOp : machine.getOps<StateOp>()) {
889 if (!stateOp.getOutputOp()) {
890 statesToErase.push_back(stateOp);
891 }
892 }
893
894 // Erase states in a separate loop to avoid iterator invalidation. We first
895 // collect all states to erase, then iterate over that list. This is
896 // necessary because erasing a state while iterating over machine.getOps()
897 // would invalidate the iterator.
898 for (StateOp stateOp : statesToErase) {
899 for (TransitionOp transition : machine.getOps<TransitionOp>()) {
900 if (transition.getNextStateOp().getSymName() == stateOp.getSymName()) {
901 transition.erase();
902 }
903 }
904 stateOp.erase();
905 }
906
907 llvm::DenseSet<BlockArgument> asyncResetBlockArguments;
908 for (auto arg : machine.getBody().front().getArguments()) {
909 if (asyncResetArguments.contains(arg.getArgNumber())) {
910 asyncResetBlockArguments.insert(arg);
911 }
912 }
913
914 // Emit a warning if reset signals were detected and removed.
915 // The FSM dialect does not support reset signals, so the reset behavior
916 // is only captured in the initial state. The original reset triggering
917 // mechanism is not preserved.
918 if (!asyncResetBlockArguments.empty()) {
919 moduleOp.emitWarning()
920 << "reset signals detected and removed from FSM; "
921 "reset behavior is captured only in the initial state";
922 }
923
924 Block &front = machine.getBody().front();
925 front.eraseArguments([&](BlockArgument arg) {
926 return asyncResetBlockArguments.contains(arg);
927 });
928 machine.getBody().front().eraseArguments([&](BlockArgument arg) {
929 return arg.getType() == seq::ClockType::get(arg.getContext());
930 });
931 FunctionType oldFunctionType = machine.getFunctionType();
932 SmallVector<Type> inputsWithoutClock;
933 for (unsigned int i = 0; i < oldFunctionType.getNumInputs(); i++) {
934 Type input = oldFunctionType.getInput(i);
935 if (input != seq::ClockType::get(input.getContext()) &&
936 !asyncResetArguments.contains(i))
937 inputsWithoutClock.push_back(input);
938 }
939
940 FunctionType newFunctionType = FunctionType::get(
941 opBuilder.getContext(), inputsWithoutClock, resultTypes);
942
943 machine.setFunctionType(newFunctionType);
944 moduleOp.erase();
945 return success();
946 }
947
948private:
949 /// Helper function to determine if a register is a state register.
950 bool isStateRegister(seq::CompRegOp reg) const {
951 auto regName = reg.getName();
952 if (!regName)
953 return false;
954
955 // If user specified state registers, check if this register's name matches
956 // any of them.
957 if (!stateRegNames.empty()) {
958 return llvm::is_contained(stateRegNames, regName->str());
959 }
960
961 // Default behavior: infer state registers by checking if the name contains
962 // "state".
963 return regName->contains("state");
964 }
965
/// The hw.module this converter operates on.
966 HWModuleOp moduleOp;
/// Builder used to create the replacement operations. Held by reference, so
/// the caller's builder must outlive this converter.
967 OpBuilder &opBuilder;
/// Names of the registers to treat as state registers. This is a non-owning
/// view; when it is empty, isStateRegister() falls back to matching register
/// names that contain "state".
968 ArrayRef<std::string> stateRegNames;
969};
970
971} // namespace
972
973namespace {
974struct CoreToFSMPass : public circt::impl::ConvertCoreToFSMBase<CoreToFSMPass> {
975 using ConvertCoreToFSMBase<CoreToFSMPass>::ConvertCoreToFSMBase;
976
977 void runOnOperation() override {
978 auto module = getOperation();
979 OpBuilder builder(module);
980
981 SmallVector<HWModuleOp> modules;
982 for (auto hwModule : module.getOps<HWModuleOp>()) {
983 modules.push_back(hwModule);
984 }
985
986 // Check for hw.instance operations - instance conversion is not supported.
987 for (auto hwModule : modules) {
988 for (auto instance : hwModule.getOps<hw::InstanceOp>()) {
989 instance.emitError() << "instance conversion is not yet supported";
990 signalPassFailure();
991 return;
992 }
993 }
994
995 for (auto hwModule : modules) {
996 builder.setInsertionPoint(hwModule);
997 HWModuleOpConverter converter(builder, hwModule, stateRegs);
998 if (failed(converter.run())) {
999 signalPassFailure();
1000 return;
1001 }
1002 }
1003 }
1004};
1005} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
create(*operands)
Definition fsm.py:160
create(name)
Definition fsm.py:144
create(to_state)
Definition fsm.py:110
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:121
Definition fsm.py:1
Definition hw.py:1
Definition seq.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21