18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/IR/Block.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/Diagnostics.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "mlir/Transforms/Passes.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/LogicalResult.h"
39#define GEN_PASS_DEF_CONVERTCORETOFSM
40#include "circt/Conversion/Passes.h.inc"
52static void generateConcatenatedValues(
53 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
54 const llvm::SmallVector<unsigned> &shifts,
55 llvm::SetVector<size_t> &finalPossibleValues);
58static void addPossibleValuesImpl(llvm::SetVector<size_t> &possibleValues,
59 Value v, llvm::DenseSet<Value> &visited) {
61 if (!visited.insert(v).second)
64 if (
auto c = dyn_cast_or_null<hw::ConstantOp>(v.getDefiningOp())) {
65 possibleValues.insert(c.getValueAttr().getValue().getZExtValue());
68 if (
auto m = dyn_cast_or_null<MuxOp>(v.getDefiningOp())) {
69 addPossibleValuesImpl(possibleValues, m.getTrueValue(), visited);
70 addPossibleValuesImpl(possibleValues, m.getFalseValue(), visited);
74 if (
auto concatOp = dyn_cast_or_null<ConcatOp>(v.getDefiningOp())) {
75 llvm::SmallVector<llvm::SetVector<size_t>> allOperandValues;
76 llvm::SmallVector<unsigned> operandWidths;
78 for (Value operand : concatOp.getOperands()) {
79 llvm::SetVector<size_t> operandPossibleValues;
80 addPossibleValuesImpl(operandPossibleValues, operand, visited);
85 auto opType = dyn_cast<IntegerType>(operand.getType());
88 assert(opType &&
"comb.concat operand must be an integer type");
89 unsigned width = opType.getWidth();
90 if (operandPossibleValues.empty()) {
91 uint64_t numStates = 1ULL << width;
94 if (numStates > 256) {
98 v.getDefiningOp()->emitWarning()
99 <<
"Search space too large (>" << 256
100 <<
" states) for operand with bitwidth " << width
101 <<
"; abandoning analysis for this path";
104 for (uint64_t i = 0; i < numStates; ++i)
105 operandPossibleValues.insert(i);
108 allOperandValues.push_back(operandPossibleValues);
109 operandWidths.push_back(width);
114 llvm::SmallVector<unsigned> shifts(concatOp.getNumOperands(), 0);
115 for (
int i = concatOp.getNumOperands() - 2; i >= 0; --i) {
116 shifts[i] = shifts[i + 1] + operandWidths[i + 1];
119 generateConcatenatedValues(allOperandValues, shifts, possibleValues);
127 auto addrType = dyn_cast<IntegerType>(v.getType());
131 unsigned bitWidth = addrType.getWidth();
134 if (v.getDefiningOp())
135 v.getDefiningOp()->emitWarning()
136 <<
"Bitwidth " << bitWidth
137 <<
" too large (>16); abandoning analysis for this path";
141 uint64_t numRegStates = 1ULL << bitWidth;
142 for (
size_t i = 0; i < numRegStates; i++) {
143 possibleValues.insert(i);
147static void addPossibleValues(llvm::SetVector<size_t> &possibleValues,
149 llvm::DenseSet<Value> visited;
150 addPossibleValuesImpl(possibleValues, v, visited);
// Checks whether `value` is an hw.constant, or a tree of comb.mux ops whose
// true/false arms eventually reach hw.constant leaves. Mux conditions are not
// pushed onto the worklist, so they are not inspected.
// NOTE(review): the handling of values without a defining op (block
// arguments), of other op kinds, and the final return are on lines missing
// from this listing; presumably anything else yields false — confirm.
155static bool isConstantOrConstantTree(Value value) {
// Iterative DFS over mux arms; `visited` skips shared subtrees and cycles.
156 SmallVector<Value> worklist;
157 llvm::DenseSet<Value> visited;
159 worklist.push_back(value);
160 while (!worklist.empty()) {
161 Value current = worklist.pop_back_val();
164 if (!visited.insert(current).second)
167 Operation *definingOp = current.getDefiningOp();
// Constant leaf: acceptable.
171 if (isa<hw::ConstantOp>(definingOp))
// Mux: both arms must themselves be constant trees.
174 if (
auto muxOp = dyn_cast<MuxOp>(definingOp)) {
175 worklist.push_back(muxOp.getTrueValue());
176 worklist.push_back(muxOp.getFalseValue());
// Pushes an equality comparison through a mux:
//   icmp eq (mux c, x, y), b  ->  mux c, (icmp eq x, b), (icmp eq y, b)
// The mirrored form, with the mux on the right-hand side, is handled the same
// way. The mux operand must satisfy isConstantOrConstantTree.
// NOTE(review): the second alternative of each `||` condition and the last
// operand of the new mux are on lines missing from this listing.
// Returns failure() when neither form applies. The signature matches a
// PatternRewriter callback (presumably registered in loadPatterns — confirm).
// Unlike the neighbouring helpers, this function is not `static`.
191LogicalResult pushIcmp(ICmpOp op, PatternRewriter &rewriter) {
// Case 1: the mux is the left-hand operand.
193 if (op.getPredicate() == ICmpPredicate::eq &&
194 op.getLhs().getDefiningOp<
MuxOp>() &&
195 (isConstantOrConstantTree(op.getLhs()) ||
197 rewriter.setInsertionPointAfter(op);
198 auto mux = op.getLhs().getDefiningOp<
MuxOp>();
199 Value x = mux.getTrueValue();
200 Value y = mux.getFalseValue();
201 Value b = op.getRhs();
202 Location loc = op.getLoc();
// Compare each mux arm against the other operand, then select between the
// two comparison results with the original mux condition.
203 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
204 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
205 rewriter.replaceOpWithNewOp<
MuxOp>(op, mux.getCond(), eq1.getResult(),
207 return llvm::success();
// Case 2: the mux is the right-hand operand (mirror of case 1).
209 if (op.getPredicate() == ICmpPredicate::eq &&
210 op.getRhs().getDefiningOp<
MuxOp>() &&
211 (isConstantOrConstantTree(op.getRhs()) ||
213 rewriter.setInsertionPointAfter(op);
214 auto mux = op.getRhs().getDefiningOp<
MuxOp>();
215 Value x = mux.getTrueValue();
216 Value y = mux.getFalseValue();
217 Value b = op.getLhs();
218 Location loc = op.getLoc();
219 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
220 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
221 rewriter.replaceOpWithNewOp<
MuxOp>(op, mux.getCond(), eq1.getResult(),
223 return llvm::success();
// Neither operand is a suitable mux.
225 return llvm::failure();
230static void generateConcatenatedValues(
231 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
232 const llvm::SmallVector<unsigned> &shifts,
233 llvm::SetVector<size_t> &finalPossibleValues) {
235 if (allOperandValues.empty()) {
236 finalPossibleValues.insert(0);
241 llvm::SetVector<size_t> currentResults;
242 for (
size_t val : allOperandValues[0])
243 currentResults.insert(val << shifts[0]);
246 for (
size_t operandIdx = 1; operandIdx < allOperandValues.size();
248 llvm::SetVector<size_t> nextResults;
249 unsigned shift = shifts[operandIdx];
251 for (
size_t partialValue : currentResults) {
252 for (
size_t val : allOperandValues[operandIdx]) {
253 nextResults.insert(partialValue | (val << shift));
256 currentResults = std::move(nextResults);
259 finalPossibleValues = std::move(currentResults);
// Decodes a packed state index into one value per register, in the order of
// `v`. Each register takes the low `bits` bits of the index, where `bits` is
// the register's width. This matches regMapToInt, which places the first
// register at the least significant bits. The rest of the loop (storing the
// value and, presumably, shifting the index right by `bits`) and the return
// are on lines missing from this listing.
// NOTE(review): `1 << bits` is an int shift and is undefined for registers of
// 31 bits or more. The loop-local `int v` shadows the parameter `v`.
262static llvm::MapVector<Value, int> intToRegMap(SmallVector<seq::CompRegOp> v,
264 llvm::MapVector<Value, int> m;
265 for (
size_t ci = 0; ci < v.size(); ci++) {
267 int bits =
reg.getType().getIntOrFloatBitWidth();
// Mask out this register's field from the packed index.
268 int v = i & ((1 << bits) - 1);
// Packs per-register values into a single state index. Register `ci`
// (in the order of `v`) is shifted left by the summed widths of the
// registers before it, so the first register occupies the least significant
// bits. This is the inverse of intToRegMap.
// `m` is taken by value, and `m[reg]` default-inserts 0 for a register that
// has no entry.
// NOTE(review): the result is an `int`, so encodings wider than about 31 bits
// overflow on narrowing.
275static int regMapToInt(SmallVector<seq::CompRegOp> v,
276 llvm::DenseMap<Value, int> m) {
279 for (
size_t ci = 0; ci < v.size(); ci++) {
// Precedence: parsed as (m[reg] * 1ULL) << width, i.e. a 64-bit shift.
281 i += m[
reg] * 1ULL << width;
282 width += (
reg.getType().getIntOrFloatBitWidth());
// Builds the cartesian product of `valueSets`: every vector picks one value
// from each set, in set order. The result is a std::set, so duplicate
// vectors collapse. The handling of an empty `valueSets`, or of an empty set
// in the middle, is on lines missing from this listing.
// NOTE(review): `¤tSet` below is an HTML-extraction artifact of
// `&currentSet` and must read `const auto &currentSet` in the real source.
288static std::set<llvm::SmallVector<size_t>> calculateCartesianProduct(
289 const llvm::SmallVector<llvm::SetVector<size_t>> &valueSets) {
290 std::set<llvm::SmallVector<size_t>> product;
291 if (valueSets.empty()) {
// Seed with one single-element vector per value of the first set.
300 for (
size_t value : valueSets.front()) {
301 product.insert({value});
// Extend every existing vector with every value of each following set.
307 for (
size_t i = 1; i < valueSets.size(); ++i) {
308 const auto ¤tSet = valueSets[i];
309 if (currentSet.empty()) {
314 std::set<llvm::SmallVector<size_t>> newProduct;
315 for (
const auto &existingVector : product) {
316 for (
size_t newValue : currentSet) {
317 llvm::SmallVector<size_t> newVector = existingVector;
318 newVector.push_back(newValue);
319 newProduct.insert(std::move(newVector));
322 product = std::move(newProduct);
// Builds a frozen pattern set that includes the canonicalization patterns of
// every dialect currently loaded in `context`.
// NOTE(review): the creation of `patterns` and any extra patterns added to it
// are on lines missing from this listing. Callers invoke this once per state
// exploration, so the set is rebuilt repeatedly.
328static FrozenRewritePatternSet loadPatterns(MLIRContext &
context) {
331 for (
auto *dialect :
context.getLoadedDialects())
332 dialect->getCanonicalizationPatterns(
patterns);
// Freeze once so the set can be reused by several greedy rewrite runs.
347 FrozenRewritePatternSet frozenPatterns(std::move(
patterns));
348 return frozenPatterns;
// Computes candidate successor states of `currentStateIndex` and adds their
// packed indices to `visitableStates`. Steps:
//  1. Clone `moduleOp` into a scratch copy.
//  2. Replace each register by the constant it holds in the current state.
//  3. Create a new hw.output that yields each register's next-value input,
//     or that constant when the register feeds back to itself.
//  4. Canonicalize the copy.
//  5. Enumerate the possible values of every yielded value, and pack every
//     combination into a state index.
// NOTE(review): the return-type line, the clone's cleanup and several other
// statements are missing from this listing. `isInitialState` is not used in
// the visible lines. `opBuilder` is taken by value, so insertion-point
// changes here do not affect the caller's builder.
352getReachableStates(llvm::SetVector<size_t> &visitableStates,
353 HWModuleOp moduleOp,
size_t currentStateIndex,
354 SmallVector<seq::CompRegOp> registers, OpBuilder opBuilder,
355 bool isInitialState) {
359 llvm::dyn_cast<HWModuleOp>(opBuilder.clone(*moduleOp, mapping));
361 llvm::MapVector<Value, int> stateMap =
362 intToRegMap(registers, currentStateIndex);
363 Operation *terminator = clonedBody.getBody().front().getTerminator();
364 auto output = dyn_cast<hw::OutputOp>(terminator);
// One new output operand per register, in stateMap (= `registers`) order.
365 SmallVector<Value> values;
367 for (
auto [originalRegValue, constStateValue] : stateMap) {
369 Value clonedRegValue = mapping.lookup(originalRegValue);
370 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
371 auto reg = cast<seq::CompRegOp>(clonedRegOp);
372 Type constantType =
reg.getType();
373 IntegerAttr constantAttr =
374 opBuilder.getIntegerAttr(constantType, constStateValue);
375 opBuilder.setInsertionPoint(clonedRegOp);
376 auto otherStateConstant =
381 Value regInput =
reg.getInput();
// A register whose input is its own output holds its value: yield the
// constant instead of the (soon replaced) register result.
382 if (regInput == clonedRegValue)
383 values.push_back(otherStateConstant.getResult());
385 values.push_back(regInput);
386 clonedRegValue.replaceAllUsesWith(otherStateConstant.getResult());
389 opBuilder.setInsertionPointToEnd(clonedBody.front().getBlock());
390 auto newOutput = hw::OutputOp::create(opBuilder, output.getLoc(), values);
// NOTE(review): the pattern set is rebuilt on every call (once per explored
// state); hoisting it would avoid repeated work.
392 FrozenRewritePatternSet frozenPatterns = loadPatterns(*moduleOp.getContext());
394 SmallVector<Operation *> opsToProcess;
395 clonedBody.walk([&](Operation *op) { opsToProcess.push_back(op); });
397 bool changed =
false;
398 GreedyRewriteConfig config;
399 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns, config,
// pv[j] = possible next values of register j.
403 llvm::SmallVector<llvm::SetVector<size_t>> pv;
404 for (
size_t j = 0; j < newOutput.getNumOperands(); j++) {
405 llvm::SetVector<size_t> possibleValues;
407 Value v = newOutput.getOperand(j);
408 addPossibleValues(possibleValues, v);
409 pv.push_back(possibleValues);
// Every combination of per-register next values is a candidate next state.
411 std::set<llvm::SmallVector<size_t>> flipped = calculateCartesianProduct(pv);
// NOTE(review): iterating by value copies each vector; `const auto &` would not.
412 for (llvm::SmallVector<size_t> v : flipped) {
413 llvm::DenseMap<Value, int> m;
414 for (
size_t k = 0; k < v.size(); k++) {
419 int i = regMapToInt(registers, m);
420 visitableStates.insert(i);
// Converts a single hw.module into an fsm.machine.
//
// Visible in this listing:
//  * Integer-typed registers are split into state registers
//    (isStateRegister) and variable registers; other register types are
//    rejected.
//  * The initial state index is packed (regMapToInt) from the registers'
//    hw.constant reset values. Variable registers become fsm.variable ops
//    initialised from their reset constants.
//  * States are explored with a worklist. For each state:
//    - the module body is cloned into the state's output region;
//    - for every candidate successor, it is cloned into a guard region;
//    - when there are variable registers, it is also cloned into an action
//      region per successor;
//    - register values are substituted by constants or fsm.variable results.
//  * Transitions whose guard folds to constant 0 are not followed, and states
//    without an output op are collected for erasure.
//  * Clock and async-reset arguments are dropped from the machine signature.
// NOTE(review): many lines of this class are missing from the listing
// (register walk header, several early returns, op creations, closing
// brace); the comments below describe only the visible statements.
428class HWModuleOpConverter {
430 HWModuleOpConverter(OpBuilder &builder,
HWModuleOp moduleOp,
431 ArrayRef<std::string> stateRegNames)
432 : moduleOp(moduleOp), opBuilder(builder), stateRegNames(stateRegNames) {}
// Performs the conversion; errors are reported as diagnostics.
433 LogicalResult
run() {
434 SmallVector<seq::CompRegOp> stateRegs;
435 SmallVector<seq::CompRegOp> variableRegs;
// Register walk (header not shown): reject non-integer registers, then
// classify each register as state or variable.
438 if (!isa<IntegerType>(
reg.getType())) {
439 reg.emitError(
"FSM extraction only supports integer-typed registers");
440 return WalkResult::interrupt();
442 if (isStateRegister(reg)) {
443 stateRegs.push_back(reg);
445 variableRegs.push_back(reg);
447 return WalkResult::advance();
449 if (walkResult.wasInterrupted())
// At least one state register is required.
451 if (stateRegs.empty()) {
452 emitError(moduleOp.getLoc())
453 <<
"Cannot find state register in this FSM. Use the state-regs "
454 "option to specify which registers are state registers.";
// `registers` defines the bit order used by intToRegMap/regMapToInt
// (the loop providing `c` is not shown; presumably over stateRegs — confirm).
457 SmallVector<seq::CompRegOp> registers;
459 registers.push_back(c);
// Bidirectional mapping between packed state indices and fsm.state ops.
462 llvm::DenseMap<size_t, StateOp> stateToStateOp;
463 llvm::DenseMap<StateOp, size_t> stateOpToState;
// Argument numbers of module inputs used as resets; removed at the end.
468 llvm::DenseSet<size_t> asyncResetArguments;
469 Location loc = moduleOp.getLoc();
470 SmallVector<Type> inputTypes = moduleOp.getInputTypes();
473 auto resultTypes = moduleOp.getOutputTypes();
474 FunctionType machineType =
475 FunctionType::get(opBuilder.getContext(), inputTypes, resultTypes);
476 StringRef machineName = moduleOp.getName();
// The initial state is derived from each register's constant reset value.
478 llvm::DenseMap<Value, int> initialStateMap;
480 Value resetValue =
reg.getResetValue();
481 auto definingConstant = resetValue.getDefiningOp<
hw::ConstantOp>();
482 if (!definingConstant) {
484 "cannot find defining constant for reset value of register");
488 definingConstant.getValueAttr().getValue().getZExtValue();
489 initialStateMap[
reg] = resetValueInt;
491 int initialStateIndex = regMapToInt(registers, initialStateMap);
493 std::string initialStateName =
"state_" + std::to_string(initialStateIndex);
// Forward port names to the machine; hw "resultNames" becomes "resNames".
496 SmallVector<NamedAttribute> machineAttrs;
497 if (
auto argNames = moduleOp->getAttrOfType<ArrayAttr>(
"argNames"))
498 machineAttrs.emplace_back(opBuilder.getStringAttr(
"argNames"), argNames);
499 if (
auto resNames = moduleOp->getAttrOfType<ArrayAttr>(
"resultNames"))
500 machineAttrs.emplace_back(opBuilder.getStringAttr(
"resNames"), resNames);
// The machine is created right before the original module.
504 opBuilder.setInsertionPoint(moduleOp);
506 MachineOp::create(opBuilder, loc, machineName, initialStateName,
507 machineType, machineAttrs);
509 OpBuilder::InsertionGuard guard(opBuilder);
510 opBuilder.setInsertionPointToStart(&machine.getBody().front());
// Variable registers become fsm.variable ops initialised from their
// constant reset value; unnamed registers are called "var".
511 llvm::MapVector<seq::CompRegOp, VariableOp> variableMap;
513 TypedValue<Type> initialValue = varReg.getResetValue();
514 auto definingConstant = initialValue.getDefiningOp<
hw::ConstantOp>();
515 if (!definingConstant) {
516 varReg->emitError(
"cannot find defining constant for reset value of "
517 "variable register");
520 auto variableOp = VariableOp::create(
521 opBuilder, varReg->getLoc(), varReg.getInput().getType(),
522 definingConstant.getValueAttr(), varReg.getName().value_or(
"var"));
523 variableMap[varReg] = variableOp;
527 FrozenRewritePatternSet frozenPatterns =
528 loadPatterns(*moduleOp.getContext());
// Worklist exploration from the initial state; `reachableStates`
// de-duplicates.
530 SetVector<int> reachableStates;
531 SmallVector<int> worklist;
533 worklist.push_back(initialStateIndex);
534 reachableStates.insert(initialStateIndex);
// Index-based loop: the worklist grows while it is being iterated.
537 for (
unsigned i = 0; i < worklist.size(); ++i) {
539 int currentStateIndex = worklist[i];
541 llvm::MapVector<Value, int> stateMap =
542 intToRegMap(registers, currentStateIndex);
544 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
// Reuse the fsm.state if it was already created as a transition target.
548 if (!stateToStateOp.contains(currentStateIndex)) {
550 "state_" + std::to_string(currentStateIndex));
551 stateToStateOp.insert({currentStateIndex, stateOp});
552 stateOpToState.insert({stateOp, currentStateIndex});
554 stateOp = stateToStateOp.lookup(currentStateIndex);
// Output region: replace its block with a clone of the module body.
556 Region &outputRegion = stateOp.getOutput();
557 Block *outputBlock = &outputRegion.front();
558 opBuilder.setInsertionPointToStart(outputBlock);
560 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), outputRegion,
561 outputBlock->getIterator(), mapping);
562 outputBlock->erase();
564 auto *terminator = outputRegion.front().getTerminator();
565 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
566 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
// The hw.output operands are re-emitted by an FSM op (creation partly
// missing from this listing).
570 OpBuilder::InsertionGuard stateGuard(opBuilder);
571 opBuilder.setInsertionPoint(hwOutputOp);
576 hwOutputOp.getOperands());
// Variable registers read the fsm.variable result instead.
583 for (
auto &[originalRegValue, variableOp] : variableMap) {
584 Value clonedRegValue = mapping.lookup(originalRegValue);
585 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
586 auto reg = cast<seq::CompRegOp>(clonedRegOp);
587 const auto res = variableOp.getResult();
588 clonedRegValue.replaceAllUsesWith(res);
// State registers become the constant they hold in this state. A reset that
// is a block argument is tied to 0. For a reset that is a comb.xor with
// isBinaryNot() of a block argument, that argument is tied to 1. Either way
// the reset is deasserted, and the argument number is recorded for removal.
591 for (
auto const &[originalRegValue, constStateValue] : stateMap) {
593 Value clonedRegValue = mapping.lookup(originalRegValue);
595 "Original register value not found in mapping");
596 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
598 assert(clonedRegOp &&
"Cloned value must have a defining op");
599 opBuilder.setInsertionPoint(clonedRegOp);
600 auto r = cast<seq::CompRegOp>(clonedRegOp);
601 TypedValue<IntegerType> registerReset = r.getReset();
603 if (BlockArgument blockArg = dyn_cast<BlockArgument>(registerReset)) {
604 asyncResetArguments.insert(blockArg.getArgNumber());
// NOTE(review): this constant uses the register's type, not
// blockArg.getType() (the xor branch below uses blockArg.getType()). For
// registers wider than the reset signal, the RAUW looks type-mismatched —
// confirm.
606 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
607 blockArg.replaceAllUsesWith(falseConst.getResult());
609 if (
auto xorOp = registerReset.getDefiningOp<
XorOp>()) {
610 if (xorOp.isBinaryNot()) {
611 Value rhs = xorOp.getOperand(0);
612 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
613 asyncResetArguments.insert(blockArg.getArgNumber());
615 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
616 blockArg.replaceAllUsesWith(trueConst.getResult());
623 clonedRegValue.getType(), constStateValue);
624 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
625 clonedRegOp->erase();
627 GreedyRewriteConfig config;
628 SmallVector<Operation *> opsToProcess;
629 outputRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
// Rewire the cloned block arguments to the machine's arguments, then drop them.
632 for (
auto arg : outputRegion.front().getArguments()) {
633 int argIndex = arg.getArgNumber();
634 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
635 arg.replaceAllUsesWith(topLevelArg);
637 outputRegion.front().eraseArguments(
638 [](BlockArgument arg) {
return true; });
// NOTE(review): built from the context alone, i.e. an empty pattern set, so
// this run presumably only folds and removes dead ops (unlike
// `frozenPatterns`, which is applied later).
639 FrozenRewritePatternSet
patterns(opBuilder.getContext());
640 config.setScope(&outputRegion);
642 bool changed =
false;
643 if (failed(applyOpPatternsGreedily(opsToProcess,
patterns, config,
646 opBuilder.setInsertionPoint(stateOp);
// Cloned ops must be topologically ordered; a cycle is a hard error.
651 bool sorted = sortTopologically(&outputRegion.front());
654 <<
"cannot convert module with combinational cycles to FSM";
// Candidate successors come from value enumeration on a scratch clone.
657 Region &transitionRegion = stateOp.getTransitions();
658 llvm::SetVector<size_t> visitableStates;
659 if (failed(getReachableStates(visitableStates, moduleOp,
660 currentStateIndex, registers, opBuilder,
661 currentStateIndex == initialStateIndex)))
// One transition per candidate successor `j`; its target state is created
// on demand.
663 for (
size_t j : visitableStates) {
665 if (!stateToStateOp.contains(j)) {
666 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
669 stateToStateOp.insert({j, toState});
670 stateOpToState.insert({toState, j});
672 toState = stateToStateOp[j];
674 opBuilder.setInsertionPointToStart(&transitionRegion.front());
// Guard region: a clone of the module body that returns true iff every
// register's next-value input equals its constant in target state `j`.
677 Region &guardRegion = transitionOp.getGuard();
678 opBuilder.createBlock(&guardRegion);
680 Block &guardBlock = guardRegion.front();
682 opBuilder.setInsertionPointToStart(&guardBlock);
684 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), guardRegion,
685 guardBlock.getIterator(), mapping);
687 Block &newGuardBlock = guardRegion.front();
688 Operation *terminator = newGuardBlock.getTerminator();
689 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
690 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
692 llvm::MapVector<Value, int> toStateMap = intToRegMap(registers, j);
693 SmallVector<Value> equalityChecks;
694 for (
auto &[originalRegValue, variableOp] : variableMap) {
695 opBuilder.setInsertionPointToStart(&newGuardBlock);
696 Value clonedRegValue = mapping.lookup(originalRegValue);
697 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
698 auto reg = cast<seq::CompRegOp>(clonedRegOp);
699 const auto res = variableOp.getResult();
700 clonedRegValue.replaceAllUsesWith(res);
703 for (
auto const &[originalRegValue, constStateValue] : toStateMap) {
705 Value clonedRegValue = mapping.lookup(originalRegValue);
706 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
707 opBuilder.setInsertionPoint(clonedRegOp);
708 auto r = cast<seq::CompRegOp>(clonedRegOp);
710 Value registerInput = r.getInput();
711 TypedValue<IntegerType> registerReset = r.getReset();
// Resets are deasserted here as well (same type caveat as in the output
// region above).
713 if (BlockArgument blockArg =
714 dyn_cast<BlockArgument>(registerReset)) {
716 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
717 blockArg.replaceAllUsesWith(falseConst.getResult());
719 if (
auto xorOp = registerReset.getDefiningOp<
XorOp>()) {
720 if (xorOp.isBinaryNot()) {
721 Value rhs = xorOp.getOperand(0);
722 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
724 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
725 blockArg.replaceAllUsesWith(trueConst.getResult());
730 Type constantType = registerInput.getType();
731 IntegerAttr constantAttr =
732 opBuilder.getIntegerAttr(constantType, constStateValue);
734 opBuilder, hwOutputOp.getLoc(), constantAttr);
737 ICmpOp::create(opBuilder, hwOutputOp.getLoc(), ICmpPredicate::eq,
738 registerInput, otherStateConstant.getResult());
739 equalityChecks.push_back(doesEqual.getResult());
// The guard is the conjunction of all per-register equalities.
741 opBuilder.setInsertionPoint(hwOutputOp);
742 auto allEqualCheck = AndOp::create(opBuilder, hwOutputOp.getLoc(),
743 equalityChecks,
false);
744 fsm::ReturnOp::create(opBuilder, hwOutputOp.getLoc(),
745 allEqualCheck.getResult());
747 for (BlockArgument arg : newGuardBlock.getArguments()) {
748 int argIndex = arg.getArgNumber();
749 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
750 arg.replaceAllUsesWith(topLevelArg);
752 newGuardBlock.eraseArguments([](BlockArgument arg) {
return true; });
// Inside the guard, the registers hold the current state's constants.
753 llvm::MapVector<Value, int> fromStateMap =
754 intToRegMap(registers, currentStateIndex);
755 for (
auto const &[originalRegValue, constStateValue] : fromStateMap) {
756 Value clonedRegValue = mapping.lookup(originalRegValue);
758 "Original register value not found in mapping");
759 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
760 assert(clonedRegOp &&
"Cloned value must have a defining op");
761 opBuilder.setInsertionPoint(clonedRegOp);
764 clonedRegValue.getType(), constStateValue);
765 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
766 clonedRegOp->erase();
// Action region (only with variable registers): another clone of the body
// that emits an fsm.update per variable register (the update value's
// operand line is missing from this listing).
768 Region &actionRegion = transitionOp.getAction();
769 if (!variableRegs.empty()) {
770 Block *actionBlock = opBuilder.createBlock(&actionRegion);
771 opBuilder.setInsertionPointToStart(actionBlock);
773 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), actionRegion,
774 actionBlock->getIterator(), mapping);
775 actionBlock->erase();
776 Block &newActionBlock = actionRegion.front();
777 for (BlockArgument arg : newActionBlock.getArguments()) {
778 int argIndex = arg.getArgNumber();
779 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
780 arg.replaceAllUsesWith(topLevelArg);
782 newActionBlock.eraseArguments([](BlockArgument arg) {
return true; });
783 for (
auto &[originalRegValue, variableOp] : variableMap) {
784 Value clonedRegValue = mapping.lookup(originalRegValue);
785 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
786 auto reg = cast<seq::CompRegOp>(clonedRegOp);
787 opBuilder.setInsertionPointToStart(&newActionBlock);
788 UpdateOp::create(opBuilder,
reg.getLoc(), variableOp,
790 const Value res = variableOp.getResult();
791 clonedRegValue.replaceAllUsesWith(res);
794 Operation *terminator = actionRegion.back().getTerminator();
795 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
796 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
// State registers hold the current state's constants in the action too.
799 for (
auto const &[originalRegValue, constStateValue] : fromStateMap) {
800 Value clonedRegValue = mapping.lookup(originalRegValue);
801 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
802 opBuilder.setInsertionPoint(clonedRegOp);
804 opBuilder, clonedRegValue.getLoc(), clonedRegValue.getType(),
806 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
807 clonedRegOp->erase();
// NOTE(review): empty pattern set again (context-only construction).
810 FrozenRewritePatternSet
patterns(opBuilder.getContext());
811 GreedyRewriteConfig config;
812 SmallVector<Operation *> opsToProcess;
813 actionRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
814 config.setScope(&actionRegion);
816 bool changed =
false;
817 if (failed(applyOpPatternsGreedily(opsToProcess,
patterns, config,
824 bool actionSorted = sortTopologically(&actionRegion.front());
827 <<
"cannot convert module with combinational cycles to FSM";
835 bool guardSorted = sortTopologically(&newGuardBlock);
838 <<
"cannot convert module with combinational cycles to FSM";
// Canonicalize the finished output and transition regions with the full
// pattern set.
841 SmallVector<Operation *> outputOps;
842 stateOp.getOutput().walk(
843 [&](Operation *op) { outputOps.push_back(op); });
845 bool changed =
false;
846 GreedyRewriteConfig config;
847 config.setScope(&stateOp.getOutput());
848 LogicalResult converged = applyOpPatternsGreedily(
849 outputOps, frozenPatterns, config, &changed);
// NOTE(review): only an assert, so non-convergence is ignored in NDEBUG
// builds; the transitions region below handles failure explicitly.
850 assert(succeeded(converged) &&
"canonicalization failed to converge");
851 SmallVector<Operation *> transitionOps;
852 stateOp.getTransitions().walk(
853 [&](Operation *op) { transitionOps.push_back(op); });
855 GreedyRewriteConfig config2;
856 config2.setScope(&stateOp.getTransitions());
857 if (failed(applyOpPatternsGreedily(transitionOps, frozenPatterns,
858 config2, &changed))) {
// Queue a successor unless its guard folded to a constant 0.
864 StateOp nextState = transition.getNextStateOp();
865 int nextStateIndex = stateOpToState.lookup(nextState);
866 auto guardConst = transition.getGuardReturn()
869 bool nextStateIsReachable =
870 !guardConst || (guardConst.getValueAttr().getInt() != 0);
873 if (nextStateIsReachable &&
874 !reachableStates.contains(nextStateIndex)) {
875 worklist.push_back(nextStateIndex);
876 reachableStates.insert(nextStateIndex);
// States that never received an output op are erased, and (presumably)
// transitions targeting them are matched by symbol name.
885 SmallVector<StateOp> statesToErase;
889 if (!stateOp.getOutputOp()) {
890 statesToErase.push_back(stateOp);
898 for (
StateOp stateOp : statesToErase) {
900 if (transition.getNextStateOp().getSymName() == stateOp.getSymName()) {
// Remove reset arguments recorded above from the machine body.
907 llvm::DenseSet<BlockArgument> asyncResetBlockArguments;
908 for (
auto arg : machine.getBody().front().getArguments()) {
909 if (asyncResetArguments.contains(arg.getArgNumber())) {
910 asyncResetBlockArguments.insert(arg);
918 if (!asyncResetBlockArguments.empty()) {
919 moduleOp.emitWarning()
920 <<
"reset signals detected and removed from FSM; "
921 "reset behavior is captured only in the initial state";
924 Block &front = machine.getBody().front();
925 front.eraseArguments([&](BlockArgument arg) {
926 return asyncResetBlockArguments.contains(arg);
// Clock arguments are dropped as well; the function type is rebuilt to
// match (indices refer to the original input list).
928 machine.getBody().front().eraseArguments([&](BlockArgument arg) {
929 return arg.getType() == seq::ClockType::get(arg.getContext());
931 FunctionType oldFunctionType = machine.getFunctionType();
932 SmallVector<Type> inputsWithoutClock;
933 for (
unsigned int i = 0; i < oldFunctionType.getNumInputs(); i++) {
934 Type input = oldFunctionType.getInput(i);
935 if (input != seq::ClockType::get(input.getContext()) &&
936 !asyncResetArguments.contains(i))
937 inputsWithoutClock.push_back(input);
940 FunctionType newFunctionType = FunctionType::get(
941 opBuilder.getContext(), inputsWithoutClock, resultTypes);
943 machine.setFunctionType(newFunctionType);
// isStateRegister: with explicit state-regs names, a register is a state
// register iff its name is listed; otherwise iff its name contains "state".
// `regName` is an optional (dereferenced with `->`); handling of unnamed
// registers is on lines missing from this listing.
951 auto regName =
reg.getName();
957 if (!stateRegNames.empty()) {
958 return llvm::is_contained(stateRegNames, regName->str());
963 return regName->contains(
"state");
// Non-owning references; they must outlive the converter.
967 OpBuilder &opBuilder;
968 ArrayRef<std::string> stateRegNames;
974struct CoreToFSMPass :
public circt::impl::ConvertCoreToFSMBase<CoreToFSMPass> {
975 using ConvertCoreToFSMBase<CoreToFSMPass>::ConvertCoreToFSMBase;
977 void runOnOperation()
override {
978 auto module = getOperation();
979 OpBuilder builder(module);
981 SmallVector<HWModuleOp> modules;
982 for (
auto hwModule : module.getOps<
HWModuleOp>()) {
983 modules.push_back(hwModule);
987 for (
auto hwModule : modules) {
988 for (
auto instance : hwModule.getOps<
hw::InstanceOp>()) {
989 instance.emitError() <<
"instance conversion is not yet supported";
995 for (
auto hwModule : modules) {
996 builder.setInsertionPoint(hwModule);
997 HWModuleOpConverter converter(builder, hwModule, stateRegs);
998 if (failed(converter.run())) {
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)