18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/IR/Block.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/Diagnostics.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "mlir/Transforms/Passes.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/LogicalResult.h"
39#define GEN_PASS_DEF_CONVERTCORETOFSM
40#include "circt/Conversion/Passes.h.inc"
/// Forward declaration. Combines the candidate values of each comb.concat
/// operand (`allOperandValues[i]`, placed at bit offset `shifts[i]`) into the
/// possible values of the whole concatenation, written to
/// `finalPossibleValues`. Needed because addPossibleValuesImpl (below) calls it
/// before its definition.
52static void generateConcatenatedValues(
53 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
54 const llvm::SmallVector<unsigned> &shifts,
55 llvm::SetVector<size_t> &finalPossibleValues);
/// Collects into `possibleValues` the integer values that `v` may take.
///
///  - hw.constant: its zero-extended value.
///  - comb.mux: the possible values of both arms, collected into the same set.
///  - comb.concat: every combination of the operands' possible values, each
///    operand shifted to its bit position (via generateConcatenatedValues).
///  - fallback (presumably any other defining op or a block argument;
///    confirm): every value representable in v's integer width. Per the
///    warning text, it gives up when the width exceeds 16 bits.
///
/// `visited` records values already examined, which keeps cycles through
/// muxes from recursing forever. It is shared by the whole walk, so a value
/// reached a second time presumably contributes nothing on that visit. For a
/// concat operand that means an empty candidate set, which then falls back to
/// enumerating the operand's full range.
58static void addPossibleValuesImpl(llvm::SetVector<size_t> &possibleValues,
59 Value v, llvm::DenseSet<Value> &visited) {
61 if (!visited.insert(v).second)
// Constant: exactly one possible value.
64 if (
auto c = dyn_cast_or_null<hw::ConstantOp>(v.getDefiningOp())) {
65 possibleValues.insert(c.getValueAttr().getValue().getZExtValue());
// Mux: either arm may be selected, so collect both into the same set.
68 if (
auto m = dyn_cast_or_null<MuxOp>(v.getDefiningOp())) {
69 addPossibleValuesImpl(possibleValues, m.getTrueValue(), visited);
70 addPossibleValuesImpl(possibleValues, m.getFalseValue(), visited);
// Concat: gather candidate values and the bit width of every operand.
74 if (
auto concatOp = dyn_cast_or_null<ConcatOp>(v.getDefiningOp())) {
75 llvm::SmallVector<llvm::SetVector<size_t>> allOperandValues;
76 llvm::SmallVector<unsigned> operandWidths;
78 for (Value operand : concatOp.getOperands()) {
79 llvm::SetVector<size_t> operandPossibleValues;
80 addPossibleValuesImpl(operandPossibleValues, operand, visited);
85 auto opType = dyn_cast<IntegerType>(operand.getType());
88 assert(opType &&
"comb.concat operand must be an integer type");
89 unsigned width = opType.getWidth();
// No candidates are known for this operand. Enumerate its whole range, but
// warn and abandon the analysis once that exceeds 256 values.
90 if (operandPossibleValues.empty()) {
// NOTE(review): `1ULL << width` is undefined for width >= 64, and it is
// evaluated before the > 256 check below. Checking `width > 8` first would
// avoid the undefined shift.
91 uint64_t numStates = 1ULL << width;
94 if (numStates > 256) {
98 v.getDefiningOp()->emitWarning()
99 <<
"Search space too large (>" << 256
100 <<
" states) for operand with bitwidth " << width
101 <<
"; abandoning analysis for this path";
104 for (uint64_t i = 0; i < numStates; ++i)
105 operandPossibleValues.insert(i);
108 allOperandValues.push_back(operandPossibleValues);
109 operandWidths.push_back(width);
// shifts[i] is the bit offset of operand i. The last operand sits at bit 0,
// and each earlier operand is placed above all the operands after it, so
// operand 0 is the most significant.
114 llvm::SmallVector<unsigned> shifts(concatOp.getNumOperands(), 0);
115 for (
int i = concatOp.getNumOperands() - 2; i >= 0; --i) {
116 shifts[i] = shifts[i + 1] + operandWidths[i + 1];
119 generateConcatenatedValues(allOperandValues, shifts, possibleValues);
// Fallback: enumerate every value representable in v's bit width, giving up
// (with a warning when v has a defining op) for wide values.
127 auto addrType = dyn_cast<IntegerType>(v.getType());
131 unsigned bitWidth = addrType.getWidth();
134 if (v.getDefiningOp())
135 v.getDefiningOp()->emitWarning()
136 <<
"Bitwidth " << bitWidth
137 <<
" too large (>16); abandoning analysis for this path";
141 uint64_t numRegStates = 1ULL << bitWidth;
142 for (
size_t i = 0; i < numRegStates; i++) {
143 possibleValues.insert(i);
/// Collects every value `v` may take into `possibleValues`. It starts a fresh
/// visited set for the recursive walk performed by addPossibleValuesImpl.
147static void addPossibleValues(llvm::SetVector<size_t> &possibleValues,
149 llvm::DenseSet<Value> visited;
150 addPossibleValuesImpl(possibleValues, v, visited);
/// Returns whether `a` (interpreted relative to `regionA`) and `b` (relative
/// to `regionB`) compute the same value structurally. Two values match when:
///  - their defining ops are equally local or non-local to their regions,
///  - they are the same result index of those ops, and
///  - the ops satisfy OperationEquivalence with locations ignored, where
///    operand pairs are compared recursively by this same function.
///
/// NOTE(review): `®ionA` in the parameter list looks like a mis-decoded
/// `&regionA`: the `&reg` prefix was rendered as an HTML entity. The body
/// refers to `regionA`, so restore the `&`.
158static bool areStructurallyEquivalent(Value a, Value b, Region ®ionA,
163 Operation *opA =
a.getDefiningOp();
164 Operation *opB =
b.getDefiningOp();
// A value defined inside its own region is only comparable to another value
// that is likewise defined inside its own region.
168 bool aIsLocal = regionA.isAncestor(opA->getParentRegion());
169 bool bIsLocal = regionB.isAncestor(opB->getParentRegion());
170 if (aIsLocal != bIsLocal)
// For multi-result ops, both values must be the same result.
177 if (cast<OpResult>(a).getResultNumber() !=
178 cast<OpResult>(b).getResultNumber())
// Compare the defining ops; operands are checked pairwise via this function.
181 return OperationEquivalence::isEquivalentTo(
183 [&](Value lhs, Value rhs) -> LogicalResult {
184 return success(areStructurallyEquivalent(lhs, rhs, regionA, regionB));
186 nullptr, OperationEquivalence::Flags::IgnoreLocations);
/// Greedy rewrite pattern that folds i1 values in a transition's action
/// region using facts established by the transition's guard.
///
/// `guardFacts` lists (value, truth) pairs known to hold when the transition
/// fires. For each op inside `actionRegion` (hw.constant excluded), an i1
/// result that has uses is replaced by a constant in two cases:
///  - the result is structurally equivalent to a guard fact; the constant is
///    that fact's truth;
///  - the op is an icmp whose predicate is the negation of a guard icmp's
///    predicate, with equivalent operands; the constant is the opposite truth.
/// The pattern matches any op type and is registered with benefit 10.
192class GuardConditionFoldPattern :
public RewritePattern {
194 GuardConditionFoldPattern(MLIRContext *ctx,
195 ArrayRef<std::pair<Value, bool>> guardFacts,
196 Region &guardRegion, Region &actionRegion)
197 : RewritePattern(MatchAnyOpTypeTag(), 10, ctx),
198 guardFacts(guardFacts.begin(), guardFacts.
end()),
199 guardRegion(guardRegion), actionRegion(actionRegion) {}
201 LogicalResult matchAndRewrite(Operation *op,
202 PatternRewriter &rewriter)
const override {
// Only rewrite ops that live inside the action region.
203 if (!actionRegion.isAncestor(op->getParentRegion()))
// Skip constants, presumably so the pattern does not keep revisiting the
// constants it introduces.
208 if (isa<hw::ConstantOp>(op))
211 for (Value result : op->getResults()) {
212 if (!result.getType().isInteger(1) || result.use_empty())
215 for (
auto [guardExpr, isTrue] : guardFacts) {
// Case 1: the result recomputes a guard fact.
217 if (areStructurallyEquivalent(guardExpr, result, guardRegion,
219 return replaceWithConstant(result, isTrue, op, rewriter);
// Case 2: the result is the negated comparison of a guard icmp.
223 if (
auto guardIcmp = guardExpr.getDefiningOp<ICmpOp>()) {
224 if (
auto actionIcmp = dyn_cast<ICmpOp>(op)) {
225 if (actionIcmp.getPredicate() ==
226 ICmpOp::getNegatedPredicate(guardIcmp.getPredicate()) &&
227 areStructurallyEquivalent(guardIcmp.getLhs(),
228 actionIcmp.getLhs(), guardRegion,
230 areStructurallyEquivalent(guardIcmp.getRhs(),
231 actionIcmp.getRhs(), guardRegion,
233 return replaceWithConstant(result, !isTrue, op, rewriter);
/// Materializes an i1 constant `constVal` at the start of the action region,
/// redirects all uses of `result` to it, and erases `op`.
/// NOTE(review): confirm `op` has no other result that is still in use before
/// it is erased here (relevant for multi-result ops).
242 LogicalResult replaceWithConstant(Value result,
bool constVal, Operation *op,
243 PatternRewriter &rewriter)
const {
244 rewriter.setInsertionPointToStart(&actionRegion.front());
246 rewriter.getI1Type(), constVal ? 1 : 0);
247 rewriter.replaceAllUsesWith(result, c);
249 rewriter.eraseOp(op);
// The facts are copied. The regions are held by reference and must outlive
// the pattern.
253 SmallVector<std::pair<Value, bool>> guardFacts;
255 Region &actionRegion;
/// Simplifies the action region of `transition` using what its guard implies.
///
/// Facts are collected starting from the guard's fsm.return operand, which is
/// known to be true. The walk looks through two kinds of ops defined inside
/// the guard region:
///  - binary-not (xor): the operand has the opposite truth;
///  - AND: every operand of a true AND is true.
/// The facts are then used in two steps:
///  1. Facts defined outside the guard region (block arguments or ops
///     elsewhere) have their uses inside the action region replaced by the
///     implied constant.
///  2. GuardConditionFoldPattern is applied greedily to the action ops, which
///     folds values that recompute (or negate) guard facts.
/// Presumably returns early when either region is empty.
264static void simplifyActionWithGuard(
TransitionOp transition,
265 OpBuilder &builder) {
266 Region &guardRegion = transition.getGuard();
267 Region &actionRegion = transition.getAction();
269 if (guardRegion.empty() || actionRegion.empty())
273 dyn_cast<fsm::ReturnOp>(guardRegion.front().getTerminator());
277 Location loc = guardReturn.getLoc();
278 Value guardCondition = guardReturn.getOperand();
281 SmallVector<std::pair<Value, bool>> guardFacts;
// Every worklist entry is known to be true. Negated facts found through
// binary-not are recorded directly in guardFacts without further
// decomposition.
284 SmallVector<std::pair<Value, bool>> worklist;
285 worklist.push_back({guardCondition,
true});
287 while (!worklist.empty()) {
288 auto [cond, isTrue] = worklist.pop_back_val();
289 guardFacts.push_back({cond, isTrue});
292 if (
auto xorOp = cond.getDefiningOp<
XorOp>()) {
293 if (xorOp.isBinaryNot() &&
294 guardRegion.isAncestor(xorOp->getParentRegion()))
295 guardFacts.push_back({xorOp.getOperand(0), !isTrue});
300 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
301 if (guardRegion.isAncestor(andOp->getParentRegion())) {
302 for (Value operand : andOp.getOperands())
303 worklist.push_back({operand,
true});
// Step 1: values the guard and the action both reference from outside the
// guard can be replaced outright inside the action region.
311 for (
auto [guardExpr, isTrue] : guardFacts) {
312 bool guardIsExternal =
313 !guardExpr.getDefiningOp() ||
314 !guardRegion.isAncestor(guardExpr.getDefiningOp()->getParentRegion());
315 if (guardIsExternal) {
316 builder.setInsertionPointToStart(&actionRegion.front());
319 guardExpr.replaceUsesWithIf(constOp.getResult(), [&](OpOperand &use) {
320 return actionRegion.isAncestor(use.getOwner()->getParentRegion());
// Step 2: fold action-local recomputations of guard facts.
326 MLIRContext *ctx = transition.getContext();
328 patterns.add<GuardConditionFoldPattern>(ctx, guardFacts, guardRegion,
330 SmallVector<Operation *> actionOps;
331 actionRegion.walk([&](Operation *op) { actionOps.push_back(op); });
332 GreedyRewriteConfig config;
333 config.setScope(&actionRegion);
// Best effort: the greedy driver's result is intentionally ignored.
334 (void)applyOpPatternsGreedily(
335 actionOps, FrozenRewritePatternSet(std::move(
patterns)), config);
/// Returns whether `value` is an hw.constant or a tree of comb.mux ops whose
/// leaves are hw.constant values. The walk uses an explicit worklist, and
/// each value is visited at most once. Presumably any other defining op, or a
/// block argument, makes the result false; confirm.
340static bool isConstantOrConstantTree(Value value) {
341 SmallVector<Value> worklist;
342 llvm::DenseSet<Value> visited;
344 worklist.push_back(value);
345 while (!worklist.empty()) {
346 Value current = worklist.pop_back_val();
// Shared sub-trees (and cycles) are only examined once.
349 if (!visited.insert(current).second)
352 Operation *definingOp = current.getDefiningOp();
// Constant leaf: acceptable.
356 if (isa<hw::ConstantOp>(definingOp))
// Mux: both arms must themselves be constants or constant trees.
359 if (
auto muxOp = dyn_cast<MuxOp>(definingOp)) {
360 worklist.push_back(muxOp.getTrueValue());
361 worklist.push_back(muxOp.getFalseValue());
/// Pushes an equality comparison through a comb.mux:
///   icmp eq (mux c, x, y), b  ->  mux c, (icmp eq x, b), (icmp eq y, b)
/// It does the same when the mux is the right-hand operand.
///
/// The rewrite applies when the mux side is a constant tree
/// (isConstantOrConstantTree) or meets the other alternative of the `||`
/// condition. Each comparison then faces a single mux arm, which later
/// folding can evaluate when the arms are constants. Returns failure, leaving
/// `op` untouched, when neither operand qualifies.
376LogicalResult pushIcmp(ICmpOp op, PatternRewriter &rewriter) {
// Case 1: the mux is the left-hand operand.
378 if (op.getPredicate() == ICmpPredicate::eq &&
379 op.getLhs().getDefiningOp<
MuxOp>() &&
380 (isConstantOrConstantTree(op.getLhs()) ||
382 rewriter.setInsertionPointAfter(op);
383 auto mux = op.getLhs().getDefiningOp<
MuxOp>();
384 Value x = mux.getTrueValue();
385 Value y = mux.getFalseValue();
386 Value
b = op.getRhs();
387 Location loc = op.getLoc();
388 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
389 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
390 rewriter.replaceOpWithNewOp<
MuxOp>(op, mux.getCond(), eq1.getResult(),
392 return llvm::success();
// Case 2: the mux is the right-hand operand. Here `b` is the original LHS, so
// the new comparisons have swapped operands; this is harmless for `eq`.
394 if (op.getPredicate() == ICmpPredicate::eq &&
395 op.getRhs().getDefiningOp<
MuxOp>() &&
396 (isConstantOrConstantTree(op.getRhs()) ||
398 rewriter.setInsertionPointAfter(op);
399 auto mux = op.getRhs().getDefiningOp<
MuxOp>();
400 Value x = mux.getTrueValue();
401 Value y = mux.getFalseValue();
402 Value
b = op.getLhs();
403 Location loc = op.getLoc();
404 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
405 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
406 rewriter.replaceOpWithNewOp<
MuxOp>(op, mux.getCond(), eq1.getResult(),
408 return llvm::success();
410 return llvm::failure();
415static void generateConcatenatedValues(
416 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
417 const llvm::SmallVector<unsigned> &shifts,
418 llvm::SetVector<size_t> &finalPossibleValues) {
420 if (allOperandValues.empty()) {
421 finalPossibleValues.insert(0);
426 llvm::SetVector<size_t> currentResults;
427 for (
size_t val : allOperandValues[0])
428 currentResults.insert(val << shifts[0]);
431 for (
size_t operandIdx = 1; operandIdx < allOperandValues.size();
433 llvm::SetVector<size_t> nextResults;
434 unsigned shift = shifts[operandIdx];
436 for (
size_t partialValue : currentResults) {
437 for (
size_t val : allOperandValues[operandIdx]) {
438 nextResults.insert(partialValue | (val << shift));
441 currentResults = std::move(nextResults);
444 finalPossibleValues = std::move(currentResults);
/// Decodes the packed state index `i` into a map from each register in `v`
/// (keyed by the register's value) to that register's field of `i`. Each
/// field is `bits` wide, where `bits` is the register's integer width.
/// Entries are inserted in the order of `v`. Presumably the fields are taken
/// from the least significant end first, mirroring regMapToInt; confirm.
///
/// NOTE(review):
///  - The local `int v` shadows the parameter `v` (the register list).
///  - `1 << bits` is undefined for registers of 31 bits or more.
///  - `v` is taken by value, which copies the register list on every call.
447static llvm::MapVector<Value, int> intToRegMap(SmallVector<seq::CompRegOp> v,
449 llvm::MapVector<Value, int> m;
450 for (
size_t ci = 0; ci < v.size(); ci++) {
452 int bits =
reg.getType().getIntOrFloatBitWidth();
// Mask out this register's field.
453 int v = i & ((1 << bits) - 1);
/// Packs a per-register value map into a single state index; this is the
/// inverse of intToRegMap. Register v[0] occupies the lowest bits, and each
/// later register is shifted left by the total width of the registers before
/// it.
///
/// NOTE(review):
///  - `m[reg] * 1ULL << width` parses as `(m[reg] * 1ULL) << width`.
///  - The function returns `int`, so the packed index overflows once the
///    registers' combined width exceeds about 31 bits.
///  - `m` is taken by value, and `m[reg]` default-inserts 0 for a register
///    that is missing from the map. Callers rely on this for registers whose
///    reset is assumed to be 0.
460static int regMapToInt(SmallVector<seq::CompRegOp> v,
461 llvm::DenseMap<Value, int> m) {
464 for (
size_t ci = 0; ci < v.size(); ci++) {
466 i += m[
reg] * 1ULL << width;
// The next register starts right above this one.
467 width += (
reg.getType().getIntOrFloatBitWidth());
/// Returns the cartesian product of `valueSets`: every vector that picks one
/// value from each set, in set order. Because the result is a std::set,
/// duplicates are removed and the tuples are ordered lexicographically.
///
/// NOTE(review): `¤tSet` below is a mis-decoded `&currentSet`: the
/// `&curren` prefix was rendered as an HTML entity. The following line refers
/// to `currentSet`, so restore the `&`.
473static std::set<llvm::SmallVector<size_t>> calculateCartesianProduct(
474 const llvm::SmallVector<llvm::SetVector<size_t>> &valueSets) {
475 std::set<llvm::SmallVector<size_t>> product;
476 if (valueSets.empty()) {
// Start with one single-element tuple per value of the first set.
485 for (
size_t value : valueSets.front()) {
486 product.insert({value});
// Extend every existing tuple with every value of each following set.
492 for (
size_t i = 1; i < valueSets.size(); ++i) {
493 const auto ¤tSet = valueSets[i];
494 if (currentSet.empty()) {
499 std::set<llvm::SmallVector<size_t>> newProduct;
500 for (
const auto &existingVector : product) {
501 for (
size_t newValue : currentSet) {
502 llvm::SmallVector<size_t> newVector = existingVector;
503 newVector.push_back(newValue);
504 newProduct.insert(std::move(newVector));
507 product = std::move(newProduct);
/// Builds and freezes the rewrite pattern set used to simplify cloned module
/// bodies. It starts from the canonicalization patterns of every dialect
/// currently loaded in `context`.
513static FrozenRewritePatternSet loadPatterns(MLIRContext &
context) {
516 for (
auto *dialect :
context.getLoadedDialects())
517 dialect->getCanonicalizationPatterns(
patterns);
532 FrozenRewritePatternSet frozenPatterns(std::move(
patterns));
533 return frozenPatterns;
/// Computes the candidate successor states of `currentStateIndex` and
/// inserts their packed indices into `visitableStates`. The steps are:
///  1. Clone `moduleOp`.
///  2. Replace every register in the clone by a constant holding its value in
///     the current state.
///  3. Build a new hw.output that exposes each register's next-state input,
///     or the constant itself when the register feeds back its own output.
///  4. Simplify the clone with the canonicalization patterns.
///  5. Enumerate the possible values of every next-state input and pack each
///     combination into a state index with regMapToInt.
///
/// The result over-approximates the real successors: combinations of
/// independently computed value sets are not checked for feasibility here.
/// NOTE(review): confirm the clone is erased on every exit path, including
/// the early exit when the greedy rewrite fails.
537getReachableStates(llvm::SetVector<size_t> &visitableStates,
538 HWModuleOp moduleOp,
size_t currentStateIndex,
539 SmallVector<seq::CompRegOp> registers, OpBuilder opBuilder,
540 bool isInitialState) {
544 llvm::dyn_cast<HWModuleOp>(opBuilder.clone(*moduleOp, mapping));
546 llvm::MapVector<Value, int> stateMap =
547 intToRegMap(registers, currentStateIndex);
548 Operation *terminator = clonedBody.getBody().front().getTerminator();
549 auto output = dyn_cast<hw::OutputOp>(terminator);
// One entry per register, in stateMap (i.e. `registers`) order; this is what
// the new hw.output will return.
550 SmallVector<Value> values;
552 for (
auto [originalRegValue, constStateValue] : stateMap) {
554 Value clonedRegValue = mapping.lookup(originalRegValue);
555 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
556 auto reg = cast<seq::CompRegOp>(clonedRegOp);
557 Type constantType =
reg.getType();
558 IntegerAttr constantAttr =
559 opBuilder.getIntegerAttr(constantType, constStateValue);
560 opBuilder.setInsertionPoint(clonedRegOp);
561 auto otherStateConstant =
566 Value regInput =
reg.getInput();
// A register that feeds its own output keeps its current value.
567 if (regInput == clonedRegValue)
568 values.push_back(otherStateConstant.getResult());
570 values.push_back(regInput);
571 clonedRegValue.replaceAllUsesWith(otherStateConstant.getResult());
574 opBuilder.setInsertionPointToEnd(clonedBody.front().getBlock());
575 auto newOutput = hw::OutputOp::create(opBuilder, output.getLoc(), values);
// NOTE(review): this rebuilds and freezes the pattern set on every call, i.e.
// once per explored state. HWModuleOpConverter::run already holds an
// equivalent set from loadPatterns that could be passed in instead.
577 FrozenRewritePatternSet frozenPatterns = loadPatterns(*moduleOp.getContext());
579 SmallVector<Operation *> opsToProcess;
580 clonedBody.walk([&](Operation *op) { opsToProcess.push_back(op); });
582 bool changed =
false;
583 GreedyRewriteConfig config;
584 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns, config,
// Possible values of each next-state input, in register order.
588 llvm::SmallVector<llvm::SetVector<size_t>> pv;
589 for (
size_t j = 0; j < newOutput.getNumOperands(); j++) {
590 llvm::SetVector<size_t> possibleValues;
592 Value v = newOutput.getOperand(j);
593 addPossibleValues(possibleValues, v);
594 pv.push_back(possibleValues);
596 std::set<llvm::SmallVector<size_t>> flipped = calculateCartesianProduct(pv);
// Note: iterates by value, copying each tuple; `const auto &` would avoid
// the copy.
597 for (llvm::SmallVector<size_t> v : flipped) {
598 llvm::DenseMap<Value, int> m;
599 for (
size_t k = 0; k < v.size(); k++) {
604 int i = regMapToInt(registers, m);
605 visitableStates.insert(i);
/// Converts one hw.module into an fsm.machine by explicit state enumeration.
///
/// Registers are split into state registers (isStateRegister) and variable
/// registers:
///  - The state registers' values are packed into an integer state index
///    (regMapToInt / intToRegMap).
///  - Variable registers become fsm.variable ops, updated in transition
///    actions.
///
/// Exploration starts from the state given by the registers' reset values.
/// Each reachable state gets an fsm.state containing:
///  - an output region: a copy of the module body specialized to that state;
///  - one fsm.transition per candidate successor from getReachableStates.
///    Its guard checks that every state register's next-state input equals
///    the successor's value.
/// Transitions whose guard folds to constant false are not followed.
613class HWModuleOpConverter {
// `stateRegNames` is a non-owning ArrayRef; the caller's storage must outlive
// the converter.
615 HWModuleOpConverter(OpBuilder &builder,
HWModuleOp moduleOp,
616 ArrayRef<std::string> stateRegNames)
617 : moduleOp(moduleOp), opBuilder(builder), stateRegNames(stateRegNames) {}
618 LogicalResult
run() {
// Classify registers; only integer-typed registers are supported.
619 SmallVector<seq::CompRegOp> stateRegs;
620 SmallVector<seq::CompRegOp> variableRegs;
623 if (!isa<IntegerType>(
reg.getType())) {
624 reg.emitError(
"FSM extraction only supports integer-typed registers");
625 return WalkResult::interrupt();
627 if (isStateRegister(reg)) {
628 stateRegs.push_back(reg);
630 variableRegs.push_back(reg);
632 return WalkResult::advance();
634 if (walkResult.wasInterrupted())
636 if (stateRegs.empty()) {
637 emitError(moduleOp.getLoc())
638 <<
"Cannot find state register in this FSM. Use the state-regs "
639 "option to specify which registers are state registers.";
// The registers whose values form the packed state index.
642 SmallVector<seq::CompRegOp> registers;
644 registers.push_back(c);
647 llvm::DenseMap<size_t, StateOp> stateToStateOp;
648 llvm::DenseMap<StateOp, size_t> stateOpToState;
// Argument numbers of module inputs found to drive register resets; these are
// removed from the machine at the end.
653 llvm::DenseSet<size_t> asyncResetArguments;
654 Location loc = moduleOp.getLoc();
655 SmallVector<Type> inputTypes = moduleOp.getInputTypes();
658 auto resultTypes = moduleOp.getOutputTypes();
659 FunctionType machineType =
660 FunctionType::get(opBuilder.getContext(), inputTypes, resultTypes);
661 StringRef machineName = moduleOp.getName();
// The initial state comes from the state registers' reset values. A register
// without a reset is assumed to start at 0 (with a warning).
663 llvm::DenseMap<Value, int> initialStateMap;
665 Value resetValue =
reg.getResetValue();
671 reg.emitWarning(
"Assuming register with no reset starts with value 0");
675 if (!definingConstant) {
677 "cannot find defining constant for reset value of register");
681 definingConstant.getValueAttr().getValue().getZExtValue();
682 initialStateMap[
reg] = resetValueInt;
684 int initialStateIndex = regMapToInt(registers, initialStateMap);
686 std::string initialStateName =
"state_" + std::to_string(initialStateIndex);
// Carry port names over: the module's "argNames" and "resultNames" become the
// machine's "argNames" and "resNames".
689 SmallVector<NamedAttribute> machineAttrs;
690 if (
auto argNames = moduleOp->getAttrOfType<ArrayAttr>(
"argNames"))
691 machineAttrs.emplace_back(opBuilder.getStringAttr(
"argNames"), argNames);
692 if (
auto resNames = moduleOp->getAttrOfType<ArrayAttr>(
"resultNames"))
693 machineAttrs.emplace_back(opBuilder.getStringAttr(
"resNames"), resNames);
697 opBuilder.setInsertionPoint(moduleOp);
699 MachineOp::create(opBuilder, loc, machineName, initialStateName,
700 machineType, machineAttrs);
702 OpBuilder::InsertionGuard guard(opBuilder);
703 opBuilder.setInsertionPointToStart(&machine.getBody().front());
// One fsm.variable per variable register, initialized from its reset constant
// (0 when there is no reset) and named after the register, or "var" when the
// register has no name.
704 llvm::MapVector<seq::CompRegOp, VariableOp> variableMap;
706 TypedValue<Type> initialValue = varReg.getResetValue();
713 "Assuming register with no reset starts with value 0");
715 varReg.getType(), 0);
717 if (!definingConstant) {
718 varReg->emitError(
"cannot find defining constant for reset value of "
719 "variable register");
722 auto variableOp = VariableOp::create(
723 opBuilder, varReg->getLoc(), varReg.getInput().getType(),
724 definingConstant.getValueAttr(), varReg.getName().value_or(
"var"));
725 variableMap[varReg] = variableOp;
729 FrozenRewritePatternSet frozenPatterns =
730 loadPatterns(*moduleOp.getContext());
// Breadth-first exploration over packed state indices. reachableStates doubles
// as the "already queued" set.
732 SetVector<int> reachableStates;
733 SmallVector<int> worklist;
735 worklist.push_back(initialStateIndex);
736 reachableStates.insert(initialStateIndex);
// Index-based loop: successors are appended to `worklist` while it is being
// processed, so iterators would be invalidated.
739 for (
unsigned i = 0; i < worklist.size(); ++i) {
741 int currentStateIndex = worklist[i];
743 llvm::MapVector<Value, int> stateMap =
744 intToRegMap(registers, currentStateIndex);
746 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
// Create the fsm.state for this index, or reuse it if an earlier transition
// already created it as a target.
750 if (!stateToStateOp.contains(currentStateIndex)) {
752 "state_" + std::to_string(currentStateIndex));
753 stateToStateOp.insert({currentStateIndex, stateOp});
754 stateOpToState.insert({stateOp, currentStateIndex});
756 stateOp = stateToStateOp.lookup(currentStateIndex);
// Output region: a copy of the module body, replacing the placeholder block.
758 Region &outputRegion = stateOp.getOutput();
759 Block *outputBlock = &outputRegion.front();
760 opBuilder.setInsertionPointToStart(outputBlock);
762 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), outputRegion,
763 outputBlock->getIterator(), mapping);
764 outputBlock->erase();
766 auto *terminator = outputRegion.front().getTerminator();
767 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
768 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
772 OpBuilder::InsertionGuard stateGuard(opBuilder);
773 opBuilder.setInsertionPoint(hwOutputOp);
778 hwOutputOp.getOperands());
// Variable registers read their fsm.variable instead.
785 for (
auto &[originalRegValue, variableOp] : variableMap) {
786 Value clonedRegValue = mapping.lookup(originalRegValue);
787 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
788 auto reg = cast<seq::CompRegOp>(clonedRegOp);
789 const auto res = variableOp.getResult();
790 clonedRegValue.replaceAllUsesWith(res);
// State registers become constants holding the current state's values.
793 for (
auto const &[originalRegValue, constStateValue] : stateMap) {
795 Value clonedRegValue = mapping.lookup(originalRegValue);
797 "Original register value not found in mapping");
798 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
800 assert(clonedRegOp &&
"Cloned value must have a defining op");
801 opBuilder.setInsertionPoint(clonedRegOp);
802 auto r = cast<seq::CompRegOp>(clonedRegOp);
// Tie reset inputs to their inactive value and record them for removal. A
// reset driven directly by an input is set to 0; a reset driven by `not
// input` has that input set to 1.
803 TypedValue<IntegerType> registerReset = r.getReset();
805 if (BlockArgument blockArg = dyn_cast<BlockArgument>(registerReset)) {
806 asyncResetArguments.insert(blockArg.getArgNumber());
// NOTE(review): this constant uses the register's type
// (clonedRegValue.getType()) rather than the reset argument's type. For
// registers wider than i1 that replaces an i1 value with a wider constant.
// The `not reset` branch below uses blockArg.getType().
808 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
809 blockArg.replaceAllUsesWith(falseConst.getResult());
811 if (
auto xorOp = registerReset.getDefiningOp<
XorOp>()) {
812 if (xorOp.isBinaryNot()) {
813 Value rhs = xorOp.getOperand(0);
814 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
815 asyncResetArguments.insert(blockArg.getArgNumber());
817 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
818 blockArg.replaceAllUsesWith(trueConst.getResult());
825 clonedRegValue.getType(), constStateValue);
826 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
827 clonedRegOp->erase();
829 GreedyRewriteConfig config;
830 SmallVector<Operation *> opsToProcess;
831 outputRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
// Use the machine's arguments in place of the copied block arguments, then
// drop the copies.
834 for (
auto arg : outputRegion.front().getArguments()) {
835 int argIndex = arg.getArgNumber();
836 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
837 arg.replaceAllUsesWith(topLevelArg);
839 outputRegion.front().eraseArguments(
840 [](BlockArgument arg) {
return true; });
// NOTE(review): this pattern set has no patterns, unlike `frozenPatterns`
// used elsewhere. This greedy run therefore presumably only folds and removes
// dead ops; confirm that this is intended.
841 FrozenRewritePatternSet
patterns(opBuilder.getContext());
842 config.setScope(&outputRegion);
844 bool changed =
false;
845 if (failed(applyOpPatternsGreedily(opsToProcess,
patterns, config,
848 opBuilder.setInsertionPoint(stateOp);
// Combinational cycles cannot be ordered, so they cannot be expressed in the
// FSM.
853 bool sorted = sortTopologically(&outputRegion.front());
856 <<
"cannot convert module with combinational cycles to FSM";
// One transition per candidate successor state.
859 Region &transitionRegion = stateOp.getTransitions();
860 llvm::SetVector<size_t> visitableStates;
861 if (failed(getReachableStates(visitableStates, moduleOp,
862 currentStateIndex, registers, opBuilder,
863 currentStateIndex == initialStateIndex)))
865 for (
size_t j : visitableStates) {
867 if (!stateToStateOp.contains(j)) {
868 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
871 stateToStateOp.insert({j, toState});
872 stateOpToState.insert({toState, j});
874 toState = stateToStateOp[j];
876 opBuilder.setInsertionPointToStart(&transitionRegion.front());
// Guard: a copy of the module body that returns the AND, over all state
// registers, of (next-state input == target state's value).
879 Region &guardRegion = transitionOp.getGuard();
880 opBuilder.createBlock(&guardRegion);
882 Block &guardBlock = guardRegion.front();
884 opBuilder.setInsertionPointToStart(&guardBlock);
886 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), guardRegion,
887 guardBlock.getIterator(), mapping);
889 Block &newGuardBlock = guardRegion.front();
890 Operation *terminator = newGuardBlock.getTerminator();
891 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
892 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
894 llvm::MapVector<Value, int> toStateMap = intToRegMap(registers, j);
895 SmallVector<Value> equalityChecks;
896 for (
auto &[originalRegValue, variableOp] : variableMap) {
897 opBuilder.setInsertionPointToStart(&newGuardBlock);
898 Value clonedRegValue = mapping.lookup(originalRegValue);
899 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
900 auto reg = cast<seq::CompRegOp>(clonedRegOp);
901 const auto res = variableOp.getResult();
902 clonedRegValue.replaceAllUsesWith(res);
905 for (
auto const &[originalRegValue, constStateValue] : toStateMap) {
907 Value clonedRegValue = mapping.lookup(originalRegValue);
908 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
909 opBuilder.setInsertionPoint(clonedRegOp);
910 auto r = cast<seq::CompRegOp>(clonedRegOp);
912 Value registerInput = r.getInput();
913 TypedValue<IntegerType> registerReset = r.getReset();
915 if (BlockArgument blockArg =
916 dyn_cast<BlockArgument>(registerReset)) {
// NOTE(review): same typing concern as in the output region. This constant is
// typed like the register, not like the i1 reset argument.
918 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
919 blockArg.replaceAllUsesWith(falseConst.getResult());
921 if (
auto xorOp = registerReset.getDefiningOp<
XorOp>()) {
922 if (xorOp.isBinaryNot()) {
923 Value rhs = xorOp.getOperand(0);
924 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
926 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
927 blockArg.replaceAllUsesWith(trueConst.getResult());
932 Type constantType = registerInput.getType();
933 IntegerAttr constantAttr =
934 opBuilder.getIntegerAttr(constantType, constStateValue);
936 opBuilder, hwOutputOp.getLoc(), constantAttr);
939 ICmpOp::create(opBuilder, hwOutputOp.getLoc(), ICmpPredicate::eq,
940 registerInput, otherStateConstant.getResult());
941 equalityChecks.push_back(doesEqual.getResult());
943 opBuilder.setInsertionPoint(hwOutputOp);
944 auto allEqualCheck = AndOp::create(opBuilder, hwOutputOp.getLoc(),
945 equalityChecks,
false);
946 fsm::ReturnOp::create(opBuilder, hwOutputOp.getLoc(),
947 allEqualCheck.getResult());
949 for (BlockArgument arg : newGuardBlock.getArguments()) {
950 int argIndex = arg.getArgNumber();
951 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
952 arg.replaceAllUsesWith(topLevelArg);
954 newGuardBlock.eraseArguments([](BlockArgument arg) {
return true; });
// Inside the guard, state registers still read the current (source) state's
// values.
955 llvm::MapVector<Value, int> fromStateMap =
956 intToRegMap(registers, currentStateIndex);
957 for (
auto const &[originalRegValue, constStateValue] : fromStateMap) {
958 Value clonedRegValue = mapping.lookup(originalRegValue);
960 "Original register value not found in mapping");
961 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
962 assert(clonedRegOp &&
"Cloned value must have a defining op");
963 opBuilder.setInsertionPoint(clonedRegOp);
966 clonedRegValue.getType(), constStateValue);
967 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
968 clonedRegOp->erase();
// Action: only needed when there are variables to update.
970 Region &actionRegion = transitionOp.getAction();
971 if (!variableRegs.empty()) {
972 Block *actionBlock = opBuilder.createBlock(&actionRegion);
973 opBuilder.setInsertionPointToStart(actionBlock);
975 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), actionRegion,
976 actionBlock->getIterator(), mapping);
977 actionBlock->erase();
978 Block &newActionBlock = actionRegion.front();
979 for (BlockArgument arg : newActionBlock.getArguments()) {
980 int argIndex = arg.getArgNumber();
981 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
982 arg.replaceAllUsesWith(topLevelArg);
984 newActionBlock.eraseArguments([](BlockArgument arg) {
return true; });
// Emit one fsm.update per variable (presumably with its register's next-state
// input), and make reads of the register use the variable.
985 for (
auto &[originalRegValue, variableOp] : variableMap) {
986 Value clonedRegValue = mapping.lookup(originalRegValue);
987 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
988 auto reg = cast<seq::CompRegOp>(clonedRegOp);
989 opBuilder.setInsertionPointToStart(&newActionBlock);
990 UpdateOp::create(opBuilder,
reg.getLoc(), variableOp,
992 const Value res = variableOp.getResult();
993 clonedRegValue.replaceAllUsesWith(res);
996 Operation *terminator = actionRegion.back().getTerminator();
997 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
998 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
// Inside the action, state registers also read the current state's values.
1001 for (
auto const &[originalRegValue, constStateValue] : fromStateMap) {
1002 Value clonedRegValue = mapping.lookup(originalRegValue);
1003 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1004 opBuilder.setInsertionPoint(clonedRegOp);
1006 opBuilder, clonedRegValue.getLoc(), clonedRegValue.getType(),
1008 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
1009 clonedRegOp->erase();
// Simplify, then topologically order, the action region.
1012 GreedyRewriteConfig config;
1013 SmallVector<Operation *> opsToProcess;
1014 actionRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
1015 config.setScope(&actionRegion);
1017 bool changed =
false;
1018 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns,
1025 bool actionSorted = sortTopologically(&actionRegion.front());
1026 if (!actionSorted) {
1027 moduleOp.emitError()
1028 <<
"cannot convert module with combinational cycles to FSM";
1036 bool guardSorted = sortTopologically(&newGuardBlock);
1038 moduleOp.emitError()
1039 <<
"cannot convert module with combinational cycles to FSM";
// Simplify the state's output region and transitions.
1042 SmallVector<Operation *> outputOps;
1043 stateOp.getOutput().walk(
1044 [&](Operation *op) { outputOps.push_back(op); });
1046 bool changed =
false;
1047 GreedyRewriteConfig config;
1048 config.setScope(&stateOp.getOutput());
1049 LogicalResult converged = applyOpPatternsGreedily(
1050 outputOps, frozenPatterns, config, &changed);
// NOTE(review): the other greedy runs in this function report failure, but
// this one is only asserted. In release builds a non-converging rewrite is
// silently ignored here.
1051 assert(succeeded(converged) &&
"canonicalization failed to converge");
1052 SmallVector<Operation *> transitionOps;
1053 stateOp.getTransitions().walk(
1054 [&](Operation *op) { transitionOps.push_back(op); });
1056 GreedyRewriteConfig config2;
1057 config2.setScope(&stateOp.getTransitions());
1058 if (failed(applyOpPatternsGreedily(transitionOps, frozenPatterns,
1059 config2, &changed))) {
// Use what each guard implies to simplify its transition's action.
1068 simplifyActionWithGuard(transition, opBuilder);
1072 SmallVector<Operation *> postOps;
1073 stateOp.getTransitions().walk(
1074 [&](Operation *op) { postOps.push_back(op); });
1075 GreedyRewriteConfig postConfig;
1076 postConfig.setScope(&stateOp.getTransitions());
1077 if (failed(applyOpPatternsGreedily(postOps, frozenPatterns,
1078 postConfig, &changed)))
// Enqueue each successor whose guard did not fold to constant false and that
// has not been queued before.
1084 StateOp nextState = transition.getNextStateOp();
1085 int nextStateIndex = stateOpToState.lookup(nextState);
1086 auto guardConst = transition.getGuardReturn()
1089 bool nextStateIsReachable =
1090 !guardConst || (guardConst.getValueAttr().getInt() != 0);
1093 if (nextStateIsReachable &&
1094 !reachableStates.contains(nextStateIndex)) {
1095 worklist.push_back(nextStateIndex);
1096 reachableStates.insert(nextStateIndex);
// Remove states left without an output op, together with the transitions that
// target them.
1105 SmallVector<StateOp> statesToErase;
1109 if (!stateOp.getOutputOp()) {
1110 statesToErase.push_back(stateOp);
1118 for (
StateOp stateOp : statesToErase) {
1120 if (transition.getNextStateOp().getSymName() == stateOp.getSymName()) {
// Drop reset inputs and clock inputs from the machine, both from its body
// block and from its function type.
1127 llvm::DenseSet<BlockArgument> asyncResetBlockArguments;
1128 for (
auto arg : machine.getBody().front().getArguments()) {
1129 if (asyncResetArguments.contains(arg.getArgNumber())) {
1130 asyncResetBlockArguments.insert(arg);
1138 if (!asyncResetBlockArguments.empty()) {
1139 moduleOp.emitWarning()
1140 <<
"reset signals detected and removed from FSM; "
1141 "reset behavior is captured only in the initial state";
1144 Block &front = machine.getBody().front();
1145 front.eraseArguments([&](BlockArgument arg) {
1146 return asyncResetBlockArguments.contains(arg);
1148 machine.getBody().front().eraseArguments([&](BlockArgument arg) {
1149 return arg.getType() == seq::ClockType::get(arg.getContext());
// Filter the original input list by index. The function type has not been
// changed yet, so its indices still match asyncResetArguments.
1151 FunctionType oldFunctionType = machine.getFunctionType();
1152 SmallVector<Type> inputsWithoutClock;
1153 for (
unsigned int i = 0; i < oldFunctionType.getNumInputs(); i++) {
1154 Type input = oldFunctionType.getInput(i);
1155 if (input != seq::ClockType::get(input.getContext()) &&
1156 !asyncResetArguments.contains(i))
1157 inputsWithoutClock.push_back(input);
1160 FunctionType newFunctionType = FunctionType::get(
1161 opBuilder.getContext(), inputsWithoutClock, resultTypes);
1163 machine.setFunctionType(newFunctionType);
// isStateRegister: when explicit names were given (stateRegNames non-empty), a
// register is a state register iff its name is in that list. Otherwise it is a
// state register iff its name contains "state".
1171 auto regName =
reg.getName();
1177 if (!stateRegNames.empty()) {
1178 return llvm::is_contained(stateRegNames, regName->str());
1183 return regName->contains(
"state");
1187 OpBuilder &opBuilder;
1188 ArrayRef<std::string> stateRegNames;
/// Pass driver: converts every top-level hw.module into an fsm.machine.
///
/// The modules are first collected into a vector, presumably so the
/// conversion can insert new ops into the parent block without disturbing the
/// iteration. Any hw.instance inside a module is reported as unsupported and
/// fails the pass. The `stateRegs` option is forwarded to HWModuleOpConverter
/// as the explicit list of state register names.
1194struct CoreToFSMPass :
public circt::impl::ConvertCoreToFSMBase<CoreToFSMPass> {
1195 using ConvertCoreToFSMBase<CoreToFSMPass>::ConvertCoreToFSMBase;
1197 void runOnOperation()
override {
1198 auto module = getOperation();
1199 OpBuilder builder(module);
1201 SmallVector<HWModuleOp> modules;
1202 for (
auto hwModule : module.getOps<
HWModuleOp>()) {
1203 modules.push_back(hwModule);
// Instances are not supported yet; reject them before converting anything.
1207 for (
auto hwModule : modules) {
1208 for (
auto instance : hwModule.getOps<
hw::InstanceOp>()) {
1209 instance.emitError() <<
"instance conversion is not yet supported";
1210 signalPassFailure();
// Each machine is created next to the module it was converted from.
1215 for (
auto hwModule : modules) {
1216 builder.setInsertionPoint(hwModule);
1217 HWModuleOpConverter converter(builder, hwModule, stateRegs);
1218 if (failed(converter.run())) {
1219 signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)