18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/IR/Block.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/BuiltinOps.h"
23#include "mlir/IR/Diagnostics.h"
24#include "mlir/IR/IRMapping.h"
25#include "mlir/IR/MLIRContext.h"
26#include "mlir/IR/Value.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Pass/PassManager.h"
29#include "mlir/Transforms/DialectConversion.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#include "mlir/Transforms/Passes.h"
32#include "llvm/ADT/DenseMap.h"
33#include "llvm/ADT/DenseSet.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/LogicalResult.h"
40#define GEN_PASS_DEF_CONVERTCORETOFSM
41#include "circt/Conversion/Passes.h.inc"
53static void generateConcatenatedValues(
54 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
55 const llvm::SmallVector<unsigned> &shifts,
56 llvm::SetVector<size_t> &finalPossibleValues);
59static void addPossibleValuesImpl(llvm::SetVector<size_t> &possibleValues,
60 Value v, llvm::DenseSet<Value> &visited) {
62 if (!visited.insert(v).second)
65 if (
auto c = dyn_cast_or_null<hw::ConstantOp>(v.getDefiningOp())) {
66 possibleValues.insert(c.getValueAttr().getValue().getZExtValue());
69 if (
auto m = dyn_cast_or_null<MuxOp>(v.getDefiningOp())) {
70 addPossibleValuesImpl(possibleValues, m.getTrueValue(), visited);
71 addPossibleValuesImpl(possibleValues, m.getFalseValue(), visited);
75 if (
auto concatOp = dyn_cast_or_null<ConcatOp>(v.getDefiningOp())) {
76 llvm::SmallVector<llvm::SetVector<size_t>> allOperandValues;
77 llvm::SmallVector<unsigned> operandWidths;
79 for (Value operand : concatOp.getOperands()) {
80 llvm::SetVector<size_t> operandPossibleValues;
81 addPossibleValuesImpl(operandPossibleValues, operand, visited);
86 auto opType = dyn_cast<IntegerType>(operand.getType());
89 assert(opType &&
"comb.concat operand must be an integer type");
90 unsigned width = opType.getWidth();
91 if (operandPossibleValues.empty()) {
92 uint64_t numStates = 1ULL << width;
95 if (numStates > 256) {
99 v.getDefiningOp()->emitWarning()
100 <<
"Search space too large (>" << 256
101 <<
" states) for operand with bitwidth " << width
102 <<
"; abandoning analysis for this path";
105 for (uint64_t i = 0; i < numStates; ++i)
106 operandPossibleValues.insert(i);
109 allOperandValues.push_back(operandPossibleValues);
110 operandWidths.push_back(width);
115 llvm::SmallVector<unsigned> shifts(concatOp.getNumOperands(), 0);
116 for (
int i = concatOp.getNumOperands() - 2; i >= 0; --i) {
117 shifts[i] = shifts[i + 1] + operandWidths[i + 1];
120 generateConcatenatedValues(allOperandValues, shifts, possibleValues);
128 auto addrType = dyn_cast<IntegerType>(v.getType());
132 unsigned bitWidth = addrType.getWidth();
135 if (v.getDefiningOp())
136 v.getDefiningOp()->emitWarning()
137 <<
"Bitwidth " << bitWidth
138 <<
" too large (>16); abandoning analysis for this path";
142 uint64_t numRegStates = 1ULL << bitWidth;
143 for (
size_t i = 0; i < numRegStates; i++) {
144 possibleValues.insert(i);
148static void addPossibleValues(llvm::SetVector<size_t> &possibleValues,
150 llvm::DenseSet<Value> visited;
151 addPossibleValuesImpl(possibleValues, v, visited);
159static bool areStructurallyEquivalent(Value a, Value b, Region ®ionA,
164 Operation *opA =
a.getDefiningOp();
165 Operation *opB =
b.getDefiningOp();
169 bool aIsLocal = regionA.isAncestor(opA->getParentRegion());
170 bool bIsLocal = regionB.isAncestor(opB->getParentRegion());
171 if (aIsLocal != bIsLocal)
178 if (cast<OpResult>(a).getResultNumber() !=
179 cast<OpResult>(b).getResultNumber())
182 return OperationEquivalence::isEquivalentTo(
184 [&](Value lhs, Value rhs) -> LogicalResult {
185 return success(areStructurallyEquivalent(lhs, rhs, regionA, regionB));
187 nullptr, OperationEquivalence::Flags::IgnoreLocations);
193class GuardConditionFoldPattern :
public RewritePattern {
195 GuardConditionFoldPattern(MLIRContext *ctx,
196 ArrayRef<std::pair<Value, bool>> guardFacts,
197 Region &guardRegion, Region &actionRegion)
198 : RewritePattern(MatchAnyOpTypeTag(), 10, ctx),
199 guardFacts(guardFacts.begin(), guardFacts.
end()),
200 guardRegion(guardRegion), actionRegion(actionRegion) {}
202 LogicalResult matchAndRewrite(Operation *op,
203 PatternRewriter &rewriter)
const override {
204 if (!actionRegion.isAncestor(op->getParentRegion()))
209 if (isa<hw::ConstantOp>(op))
212 for (Value result : op->getResults()) {
213 if (!result.getType().isInteger(1) || result.use_empty())
216 for (
auto [guardExpr, isTrue] : guardFacts) {
218 if (areStructurallyEquivalent(guardExpr, result, guardRegion,
220 return replaceWithConstant(result, isTrue, op, rewriter);
224 if (
auto guardIcmp = guardExpr.getDefiningOp<ICmpOp>()) {
225 if (
auto actionIcmp = dyn_cast<ICmpOp>(op)) {
226 if (actionIcmp.getPredicate() ==
227 ICmpOp::getNegatedPredicate(guardIcmp.getPredicate()) &&
228 areStructurallyEquivalent(guardIcmp.getLhs(),
229 actionIcmp.getLhs(), guardRegion,
231 areStructurallyEquivalent(guardIcmp.getRhs(),
232 actionIcmp.getRhs(), guardRegion,
234 return replaceWithConstant(result, !isTrue, op, rewriter);
243 LogicalResult replaceWithConstant(Value result,
bool constVal, Operation *op,
244 PatternRewriter &rewriter)
const {
245 rewriter.setInsertionPointToStart(&actionRegion.front());
247 rewriter.getI1Type(), constVal ? 1 : 0);
248 rewriter.replaceAllUsesWith(result, c);
250 rewriter.eraseOp(op);
254 SmallVector<std::pair<Value, bool>> guardFacts;
256 Region &actionRegion;
265static void simplifyActionWithGuard(
TransitionOp transition,
266 OpBuilder &builder) {
267 Region &guardRegion = transition.getGuard();
268 Region &actionRegion = transition.getAction();
270 if (guardRegion.empty() || actionRegion.empty())
274 dyn_cast<fsm::ReturnOp>(guardRegion.front().getTerminator());
278 Location loc = guardReturn.getLoc();
279 Value guardCondition = guardReturn.getOperand();
282 SmallVector<std::pair<Value, bool>> guardFacts;
285 SmallVector<std::pair<Value, bool>> worklist;
286 worklist.push_back({guardCondition,
true});
288 while (!worklist.empty()) {
289 auto [cond, isTrue] = worklist.pop_back_val();
290 guardFacts.push_back({cond, isTrue});
293 if (
auto xorOp = cond.getDefiningOp<
XorOp>()) {
294 if (xorOp.isBinaryNot() &&
295 guardRegion.isAncestor(xorOp->getParentRegion()))
296 guardFacts.push_back({xorOp.getOperand(0), !isTrue});
301 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
302 if (guardRegion.isAncestor(andOp->getParentRegion())) {
303 for (Value operand : andOp.getOperands())
304 worklist.push_back({operand,
true});
312 for (
auto [guardExpr, isTrue] : guardFacts) {
313 bool guardIsExternal =
314 !guardExpr.getDefiningOp() ||
315 !guardRegion.isAncestor(guardExpr.getDefiningOp()->getParentRegion());
316 if (guardIsExternal) {
317 builder.setInsertionPointToStart(&actionRegion.front());
320 guardExpr.replaceUsesWithIf(constOp.getResult(), [&](OpOperand &use) {
321 if (!actionRegion.isAncestor(use.getOwner()->getParentRegion()))
325 if (auto updateOp = dyn_cast<UpdateOp>(use.getOwner()))
326 if (use.getOperandNumber() == 0)
334 MLIRContext *ctx = transition.getContext();
336 patterns.add<GuardConditionFoldPattern>(ctx, guardFacts, guardRegion,
338 SmallVector<Operation *> actionOps;
339 actionRegion.walk([&](Operation *op) { actionOps.push_back(op); });
340 GreedyRewriteConfig config;
341 config.setScope(&actionRegion);
342 (void)applyOpPatternsGreedily(
343 actionOps, FrozenRewritePatternSet(std::move(
patterns)), config);
348static bool isConstantOrConstantTree(Value value) {
349 SmallVector<Value> worklist;
350 llvm::DenseSet<Value> visited;
352 worklist.push_back(value);
353 while (!worklist.empty()) {
354 Value current = worklist.pop_back_val();
357 if (!visited.insert(current).second)
360 Operation *definingOp = current.getDefiningOp();
364 if (isa<hw::ConstantOp>(definingOp))
367 if (
auto muxOp = dyn_cast<MuxOp>(definingOp)) {
368 worklist.push_back(muxOp.getTrueValue());
369 worklist.push_back(muxOp.getFalseValue());
384LogicalResult pushIcmp(ICmpOp op, PatternRewriter &rewriter) {
386 if (op.getPredicate() == ICmpPredicate::eq &&
387 op.getLhs().getDefiningOp<
MuxOp>() &&
388 (isConstantOrConstantTree(op.getLhs()) ||
390 rewriter.setInsertionPointAfter(op);
391 auto mux = op.getLhs().getDefiningOp<
MuxOp>();
392 Value x = mux.getTrueValue();
393 Value y = mux.getFalseValue();
394 Value
b = op.getRhs();
395 Location loc = op.getLoc();
396 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
397 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
398 rewriter.replaceOpWithNewOp<
MuxOp>(op, mux.getCond(), eq1.getResult(),
400 return llvm::success();
402 if (op.getPredicate() == ICmpPredicate::eq &&
403 op.getRhs().getDefiningOp<
MuxOp>() &&
404 (isConstantOrConstantTree(op.getRhs()) ||
406 rewriter.setInsertionPointAfter(op);
407 auto mux = op.getRhs().getDefiningOp<
MuxOp>();
408 Value x = mux.getTrueValue();
409 Value y = mux.getFalseValue();
410 Value
b = op.getLhs();
411 Location loc = op.getLoc();
412 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
413 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
414 rewriter.replaceOpWithNewOp<
MuxOp>(op, mux.getCond(), eq1.getResult(),
416 return llvm::success();
418 return llvm::failure();
423static void generateConcatenatedValues(
424 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
425 const llvm::SmallVector<unsigned> &shifts,
426 llvm::SetVector<size_t> &finalPossibleValues) {
428 if (allOperandValues.empty()) {
429 finalPossibleValues.insert(0);
434 llvm::SetVector<size_t> currentResults;
435 for (
size_t val : allOperandValues[0])
436 currentResults.insert(val << shifts[0]);
439 for (
size_t operandIdx = 1; operandIdx < allOperandValues.size();
441 llvm::SetVector<size_t> nextResults;
442 unsigned shift = shifts[operandIdx];
444 for (
size_t partialValue : currentResults) {
445 for (
size_t val : allOperandValues[operandIdx]) {
446 nextResults.insert(partialValue | (val << shift));
449 currentResults = std::move(nextResults);
452 for (
size_t val : currentResults)
453 finalPossibleValues.insert(val);
459 for (
size_t ci = 0; ci < v.size(); ci++) {
461 int bits =
reg.getType().getIntOrFloatBitWidth();
462 int v = i & ((1 << bits) - 1);
469static int regMapToInt(SmallVector<seq::CompRegOp> v,
470 llvm::DenseMap<Value, int> m) {
473 for (
size_t ci = 0; ci < v.size(); ci++) {
475 i += m[
reg] * 1ULL << width;
476 width += (
reg.getType().getIntOrFloatBitWidth());
482static std::set<llvm::SmallVector<size_t>> calculateCartesianProduct(
483 const llvm::SmallVector<llvm::SetVector<size_t>> &valueSets) {
484 std::set<llvm::SmallVector<size_t>> product;
485 if (valueSets.empty()) {
494 for (
size_t value : valueSets.front()) {
495 product.insert({value});
501 for (
size_t i = 1; i < valueSets.size(); ++i) {
502 const auto ¤tSet = valueSets[i];
503 if (currentSet.empty()) {
508 std::set<llvm::SmallVector<size_t>> newProduct;
509 for (
const auto &existingVector : product) {
510 for (
size_t newValue : currentSet) {
511 llvm::SmallVector<size_t> newVector = existingVector;
512 newVector.push_back(newValue);
513 newProduct.insert(std::move(newVector));
516 product = std::move(newProduct);
522static FrozenRewritePatternSet loadPatterns(MLIRContext &
context) {
525 for (
auto *dialect :
context.getLoadedDialects())
526 dialect->getCanonicalizationPatterns(
patterns);
541 FrozenRewritePatternSet frozenPatterns(std::move(
patterns));
542 return frozenPatterns;
546getReachableStates(llvm::SetVector<size_t> &visitableStates,
547 HWModuleOp moduleOp,
size_t currentStateIndex,
548 SmallVector<seq::CompRegOp> registers) {
553 mlir::ModuleOp::create(moduleOp.getLoc());
554 OpBuilder
b(moduleOp.getContext());
555 b.setInsertionPointToStart(analysisModule->getBody());
558 auto clonedBody = llvm::dyn_cast<HWModuleOp>(
b.clone(*moduleOp, mapping));
561 intToRegMap(registers, currentStateIndex);
562 Operation *terminator = clonedBody.getBody().front().getTerminator();
563 auto output = dyn_cast<hw::OutputOp>(terminator);
564 SmallVector<Value> values;
566 for (
auto [originalRegValue, constStateValue] : stateMap) {
568 Value clonedRegValue = mapping.lookup(originalRegValue);
569 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
570 auto reg = cast<seq::CompRegOp>(clonedRegOp);
571 Type constantType =
reg.getType();
572 IntegerAttr constantAttr =
b.getIntegerAttr(constantType, constStateValue);
573 b.setInsertionPoint(clonedRegOp);
574 auto otherStateConstant =
579 Value regInput =
reg.getInput();
580 if (regInput == clonedRegValue)
581 values.push_back(otherStateConstant.getResult());
583 values.push_back(regInput);
584 clonedRegValue.replaceAllUsesWith(otherStateConstant.getResult());
587 b.setInsertionPointToEnd(clonedBody.front().getBlock());
588 auto newOutput = hw::OutputOp::create(b, output.getLoc(), values);
593 SmallVector<hw::ModulePort> newPorts;
596 newPorts.push_back(p);
597 for (
auto [i, val] :
llvm::enumerate(values))
598 newPorts.push_back({
b.getStringAttr(
"out" + std::to_string(i)),
599 val.getType(), hw::ModulePort::Direction::Output});
600 clonedBody.setHWModuleType(hw::ModuleType::get(
b.getContext(), newPorts));
602 FrozenRewritePatternSet frozenPatterns = loadPatterns(*moduleOp.getContext());
604 SmallVector<Operation *> opsToProcess;
605 clonedBody.walk([&](Operation *op) { opsToProcess.push_back(op); });
607 bool changed =
false;
608 GreedyRewriteConfig config;
609 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns, config,
613 llvm::SmallVector<llvm::SetVector<size_t>> pv;
614 for (
size_t j = 0; j < newOutput.getNumOperands(); j++) {
615 llvm::SetVector<size_t> possibleValues;
617 Value v = newOutput.getOperand(j);
618 addPossibleValues(possibleValues, v);
619 pv.push_back(possibleValues);
621 std::set<llvm::SmallVector<size_t>> flipped = calculateCartesianProduct(pv);
622 for (llvm::SmallVector<size_t> v : flipped) {
623 llvm::DenseMap<Value, int> m;
624 for (
size_t k = 0; k < v.size(); k++) {
629 int i = regMapToInt(registers, m);
630 visitableStates.insert(i);
638class HWModuleOpConverter {
640 HWModuleOpConverter(OpBuilder &builder,
HWModuleOp moduleOp,
641 ArrayRef<std::string> stateRegNames)
642 : moduleOp(moduleOp), opBuilder(builder), stateRegNames(stateRegNames) {}
643 LogicalResult
run() {
644 SmallVector<seq::CompRegOp> stateRegs;
645 SmallVector<seq::CompRegOp> variableRegs;
646 Value foundClock, foundReset =
nullptr;
649 auto reset =
reg.getReset();
651 if (clk != foundClock) {
652 reg.emitError(
"All registers must have the same clock signal.");
653 return WalkResult::interrupt();
661 if (reset != foundReset) {
662 reg.emitError(
"All registers must have the same reset signal.");
663 return WalkResult::interrupt();
671 if (!isa<IntegerType>(
reg.getType())) {
672 reg.emitError(
"FSM extraction only supports integer-typed registers");
673 return WalkResult::interrupt();
675 if (isStateRegister(reg)) {
676 stateRegs.push_back(reg);
678 variableRegs.push_back(reg);
680 return WalkResult::advance();
682 if (walkResult.wasInterrupted())
684 if (stateRegs.empty()) {
685 emitError(moduleOp.getLoc())
686 <<
"Cannot find state register in this FSM. Use the state-regs "
687 "option to specify which registers are state registers.";
690 SmallVector<seq::CompRegOp> registers;
692 registers.push_back(c);
695 llvm::DenseMap<size_t, StateOp> stateToStateOp;
696 llvm::DenseMap<StateOp, size_t> stateOpToState;
701 llvm::DenseSet<size_t> asyncResetArguments;
702 Location loc = moduleOp.getLoc();
703 SmallVector<Type> inputTypes = moduleOp.getInputTypes();
706 auto resultTypes = moduleOp.getOutputTypes();
707 FunctionType machineType =
708 FunctionType::get(opBuilder.getContext(), inputTypes, resultTypes);
709 StringRef machineName = moduleOp.getName();
711 llvm::DenseMap<Value, int> initialStateMap;
713 Value resetValue =
reg.getResetValue();
719 reg.emitWarning(
"Assuming register with no reset starts with value 0");
723 if (!definingConstant) {
725 "cannot find defining constant for reset value of register");
729 definingConstant.getValueAttr().getValue().getZExtValue();
730 initialStateMap[
reg] = resetValueInt;
732 int initialStateIndex = regMapToInt(registers, initialStateMap);
734 std::string initialStateName =
"state_" + std::to_string(initialStateIndex);
737 SmallVector<NamedAttribute> machineAttrs;
738 if (
auto argNames = moduleOp->getAttrOfType<ArrayAttr>(
"argNames"))
739 machineAttrs.emplace_back(opBuilder.getStringAttr(
"argNames"), argNames);
740 if (
auto resNames = moduleOp->getAttrOfType<ArrayAttr>(
"resultNames"))
741 machineAttrs.emplace_back(opBuilder.getStringAttr(
"resNames"), resNames);
745 opBuilder.setInsertionPoint(moduleOp);
747 MachineOp::create(opBuilder, loc, machineName, initialStateName,
748 machineType, machineAttrs);
750 OpBuilder::InsertionGuard guard(opBuilder);
751 opBuilder.setInsertionPointToStart(&machine.getBody().front());
754 TypedValue<Type> initialValue = varReg.getResetValue();
761 "Assuming register with no reset starts with value 0");
763 varReg.getType(), 0);
765 if (!definingConstant) {
766 varReg->emitError(
"cannot find defining constant for reset value of "
767 "variable register");
770 auto variableOp = VariableOp::create(
771 opBuilder, varReg->getLoc(), varReg.getInput().getType(),
772 definingConstant.getValueAttr(), varReg.getName().value_or(
"var"));
773 variableMap[varReg] = variableOp;
777 FrozenRewritePatternSet frozenPatterns =
778 loadPatterns(*moduleOp.getContext());
780 SetVector<int> reachableStates;
781 SmallVector<int> worklist;
783 worklist.push_back(initialStateIndex);
784 reachableStates.insert(initialStateIndex);
787 for (
unsigned i = 0; i < worklist.size(); ++i) {
789 int currentStateIndex = worklist[i];
792 intToRegMap(registers, currentStateIndex);
794 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
798 if (!stateToStateOp.contains(currentStateIndex)) {
800 "state_" + std::to_string(currentStateIndex));
801 stateToStateOp.insert({currentStateIndex, stateOp});
802 stateOpToState.insert({stateOp, currentStateIndex});
804 stateOp = stateToStateOp.lookup(currentStateIndex);
806 Region &outputRegion = stateOp.getOutput();
807 Block *outputBlock = &outputRegion.front();
808 opBuilder.setInsertionPointToStart(outputBlock);
810 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), outputRegion,
811 outputBlock->getIterator(), mapping);
812 outputBlock->erase();
814 auto *terminator = outputRegion.front().getTerminator();
815 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
816 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
820 OpBuilder::InsertionGuard stateGuard(opBuilder);
821 opBuilder.setInsertionPoint(hwOutputOp);
826 hwOutputOp.getOperands());
833 for (
auto &[originalRegValue, variableOp] : variableMap) {
834 Value clonedRegValue = mapping.lookup(originalRegValue);
835 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
836 auto reg = cast<seq::CompRegOp>(clonedRegOp);
837 const auto res = variableOp.getResult();
838 clonedRegValue.replaceAllUsesWith(res);
841 for (
auto const &[originalRegValue, constStateValue] : stateMap) {
843 Value clonedRegValue = mapping.lookup(originalRegValue);
845 "Original register value not found in mapping");
846 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
848 assert(clonedRegOp &&
"Cloned value must have a defining op");
849 opBuilder.setInsertionPoint(clonedRegOp);
850 auto r = cast<seq::CompRegOp>(clonedRegOp);
851 TypedValue<IntegerType> registerReset = r.getReset();
853 if (BlockArgument blockArg = dyn_cast<BlockArgument>(registerReset)) {
854 asyncResetArguments.insert(blockArg.getArgNumber());
856 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
857 blockArg.replaceAllUsesWith(falseConst.getResult());
859 if (
auto xorOp = registerReset.getDefiningOp<
XorOp>()) {
860 if (xorOp.isBinaryNot()) {
861 Value rhs = xorOp.getOperand(0);
862 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
863 asyncResetArguments.insert(blockArg.getArgNumber());
865 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
866 blockArg.replaceAllUsesWith(trueConst.getResult());
873 clonedRegValue.getType(), constStateValue);
874 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
875 clonedRegOp->erase();
877 GreedyRewriteConfig config;
878 SmallVector<Operation *> opsToProcess;
879 outputRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
882 for (
auto arg : outputRegion.front().getArguments()) {
883 int argIndex = arg.getArgNumber();
884 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
885 arg.replaceAllUsesWith(topLevelArg);
887 outputRegion.front().eraseArguments(
888 [](BlockArgument arg) {
return true; });
893 bool sorted = sortTopologically(&outputRegion.front());
896 <<
"cannot convert module with combinational cycles to FSM";
899 FrozenRewritePatternSet
patterns(opBuilder.getContext());
900 config.setScope(&outputRegion);
902 bool changed =
false;
903 if (failed(applyOpPatternsGreedily(opsToProcess,
patterns, config,
906 opBuilder.setInsertionPoint(stateOp);
907 Region &transitionRegion = stateOp.getTransitions();
908 llvm::SetVector<size_t> visitableStates;
909 if (failed(getReachableStates(visitableStates, moduleOp,
910 currentStateIndex, registers)))
912 for (
size_t j : visitableStates) {
914 if (!stateToStateOp.contains(j)) {
915 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
918 stateToStateOp.insert({j, toState});
919 stateOpToState.insert({toState, j});
921 toState = stateToStateOp[j];
923 opBuilder.setInsertionPointToStart(&transitionRegion.front());
926 Region &guardRegion = transitionOp.getGuard();
927 opBuilder.createBlock(&guardRegion);
929 Block &guardBlock = guardRegion.front();
931 opBuilder.setInsertionPointToStart(&guardBlock);
933 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), guardRegion,
934 guardBlock.getIterator(), mapping);
936 Block &newGuardBlock = guardRegion.front();
937 Operation *terminator = newGuardBlock.getTerminator();
938 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
939 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
942 SmallVector<Value> equalityChecks;
943 for (
auto &[originalRegValue, variableOp] : variableMap) {
944 opBuilder.setInsertionPointToStart(&newGuardBlock);
945 Value clonedRegValue = mapping.lookup(originalRegValue);
946 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
947 auto reg = cast<seq::CompRegOp>(clonedRegOp);
948 const auto res = variableOp.getResult();
949 clonedRegValue.replaceAllUsesWith(res);
952 for (
auto const &[originalRegValue, constStateValue] : toStateMap) {
954 Value clonedRegValue = mapping.lookup(originalRegValue);
955 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
956 opBuilder.setInsertionPoint(clonedRegOp);
957 auto r = cast<seq::CompRegOp>(clonedRegOp);
959 Value registerInput = r.getInput();
960 TypedValue<IntegerType> registerReset = r.getReset();
962 if (BlockArgument blockArg =
963 dyn_cast<BlockArgument>(registerReset)) {
965 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
966 blockArg.replaceAllUsesWith(falseConst.getResult());
968 if (
auto xorOp = registerReset.getDefiningOp<
XorOp>()) {
969 if (xorOp.isBinaryNot()) {
970 Value rhs = xorOp.getOperand(0);
971 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
973 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
974 blockArg.replaceAllUsesWith(trueConst.getResult());
979 Type constantType = registerInput.getType();
980 IntegerAttr constantAttr =
981 opBuilder.getIntegerAttr(constantType, constStateValue);
983 opBuilder, hwOutputOp.getLoc(), constantAttr);
986 ICmpOp::create(opBuilder, hwOutputOp.getLoc(), ICmpPredicate::eq,
987 registerInput, otherStateConstant.getResult());
988 equalityChecks.push_back(doesEqual.getResult());
990 opBuilder.setInsertionPoint(hwOutputOp);
991 auto allEqualCheck = AndOp::create(opBuilder, hwOutputOp.getLoc(),
992 equalityChecks,
false);
993 fsm::ReturnOp::create(opBuilder, hwOutputOp.getLoc(),
994 allEqualCheck.getResult());
996 for (BlockArgument arg : newGuardBlock.getArguments()) {
997 int argIndex = arg.getArgNumber();
998 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
999 arg.replaceAllUsesWith(topLevelArg);
1001 newGuardBlock.eraseArguments([](BlockArgument arg) {
return true; });
1003 intToRegMap(registers, currentStateIndex);
1004 for (
auto const &[originalRegValue, constStateValue] : fromStateMap) {
1005 Value clonedRegValue = mapping.lookup(originalRegValue);
1007 "Original register value not found in mapping");
1008 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1009 assert(clonedRegOp &&
"Cloned value must have a defining op");
1010 opBuilder.setInsertionPoint(clonedRegOp);
1013 clonedRegValue.getType(), constStateValue);
1014 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
1015 clonedRegOp->erase();
1019 bool guardSorted = sortTopologically(&newGuardBlock);
1021 moduleOp.emitError()
1022 <<
"cannot convert module with combinational cycles to FSM";
1025 Region &actionRegion = transitionOp.getAction();
1026 if (!variableRegs.empty()) {
1027 Block *actionBlock = opBuilder.createBlock(&actionRegion);
1028 opBuilder.setInsertionPointToStart(actionBlock);
1030 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), actionRegion,
1031 actionBlock->getIterator(), mapping);
1032 actionBlock->erase();
1033 Block &newActionBlock = actionRegion.front();
1034 for (BlockArgument arg : newActionBlock.getArguments()) {
1035 int argIndex = arg.getArgNumber();
1036 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
1037 arg.replaceAllUsesWith(topLevelArg);
1039 newActionBlock.eraseArguments([](BlockArgument arg) {
return true; });
1040 for (
auto &[originalRegValue, variableOp] : variableMap) {
1041 Value clonedRegValue = mapping.lookup(originalRegValue);
1042 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1043 auto reg = cast<seq::CompRegOp>(clonedRegOp);
1044 opBuilder.setInsertionPointToStart(&newActionBlock);
1045 UpdateOp::create(opBuilder,
reg.getLoc(), variableOp,
1047 const Value res = variableOp.getResult();
1048 clonedRegValue.replaceAllUsesWith(res);
1051 Operation *terminator = actionRegion.back().getTerminator();
1052 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
1053 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
1056 for (
auto const &[originalRegValue, constStateValue] : fromStateMap) {
1057 Value clonedRegValue = mapping.lookup(originalRegValue);
1058 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1059 opBuilder.setInsertionPoint(clonedRegOp);
1061 opBuilder, clonedRegValue.getLoc(), clonedRegValue.getType(),
1063 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
1064 clonedRegOp->erase();
1069 bool actionSorted = sortTopologically(&actionRegion.front());
1070 if (!actionSorted) {
1071 moduleOp.emitError()
1072 <<
"cannot convert module with combinational cycles to FSM";
1076 GreedyRewriteConfig config;
1077 SmallVector<Operation *> opsToProcess;
1078 actionRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
1079 config.setScope(&actionRegion);
1081 bool changed =
false;
1082 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns,
1087 SmallVector<Operation *> outputOps;
1088 stateOp.getOutput().walk(
1089 [&](Operation *op) { outputOps.push_back(op); });
1091 bool changed =
false;
1092 GreedyRewriteConfig config;
1093 config.setScope(&stateOp.getOutput());
1094 LogicalResult converged = applyOpPatternsGreedily(
1095 outputOps, frozenPatterns, config, &changed);
1096 assert(succeeded(converged) &&
"canonicalization failed to converge");
1097 SmallVector<Operation *> transitionOps;
1098 stateOp.getTransitions().walk(
1099 [&](Operation *op) { transitionOps.push_back(op); });
1101 GreedyRewriteConfig config2;
1102 config2.setScope(&stateOp.getTransitions());
1103 if (failed(applyOpPatternsGreedily(transitionOps, frozenPatterns,
1104 config2, &changed))) {
1113 simplifyActionWithGuard(transition, opBuilder);
1117 SmallVector<Operation *> postOps;
1118 stateOp.getTransitions().walk(
1119 [&](Operation *op) { postOps.push_back(op); });
1120 GreedyRewriteConfig postConfig;
1121 postConfig.setScope(&stateOp.getTransitions());
1122 if (failed(applyOpPatternsGreedily(postOps, frozenPatterns,
1123 postConfig, &changed)))
1129 StateOp nextState = transition.getNextStateOp();
1130 int nextStateIndex = stateOpToState.lookup(nextState);
1131 auto guardConst = transition.getGuardReturn()
1134 bool nextStateIsReachable =
1135 !guardConst || (guardConst.getValueAttr().getInt() != 0);
1138 if (nextStateIsReachable &&
1139 !reachableStates.contains(nextStateIndex)) {
1140 worklist.push_back(nextStateIndex);
1141 reachableStates.insert(nextStateIndex);
1150 SmallVector<StateOp> statesToErase;
1154 if (!stateOp.getOutputOp()) {
1155 statesToErase.push_back(stateOp);
1163 for (
StateOp stateOp : statesToErase) {
1165 if (transition.getNextStateOp().getSymName() == stateOp.getSymName()) {
1172 llvm::DenseSet<BlockArgument> asyncResetBlockArguments;
1173 for (
auto arg : machine.getBody().front().getArguments()) {
1174 if (asyncResetArguments.contains(arg.getArgNumber())) {
1175 asyncResetBlockArguments.insert(arg);
1183 if (!asyncResetBlockArguments.empty()) {
1184 moduleOp.emitWarning()
1185 <<
"reset signals detected and removed from FSM; "
1186 "reset behavior is captured only in the initial state";
1189 Block &front = machine.getBody().front();
1190 front.eraseArguments([&](BlockArgument arg) {
1191 return asyncResetBlockArguments.contains(arg);
1194 if (llvm::any_of(front.getArguments(), [](BlockArgument arg) {
1195 return arg.getType() == seq::ClockType::get(arg.getContext()) &&
1196 arg.hasNUsesOrMore(1);
1198 moduleOp.emitError(
"Clock uses outside register clocking are not "
1199 "currently supported.");
1202 machine.getBody().front().eraseArguments([&](BlockArgument arg) {
1203 return arg.getType() == seq::ClockType::get(arg.getContext());
1205 FunctionType oldFunctionType = machine.getFunctionType();
1206 SmallVector<Type> inputsWithoutClock;
1207 for (
unsigned int i = 0; i < oldFunctionType.getNumInputs(); i++) {
1208 Type input = oldFunctionType.getInput(i);
1209 if (input != seq::ClockType::get(input.getContext()) &&
1210 !asyncResetArguments.contains(i))
1211 inputsWithoutClock.push_back(input);
1214 FunctionType newFunctionType = FunctionType::get(
1215 opBuilder.getContext(), inputsWithoutClock, resultTypes);
1217 machine.setFunctionType(newFunctionType);
1225 auto regName =
reg.getName();
1231 if (!stateRegNames.empty()) {
1232 return llvm::is_contained(stateRegNames, regName->str());
1237 return regName->contains(
"state");
1241 OpBuilder &opBuilder;
1242 ArrayRef<std::string> stateRegNames;
1248struct CoreToFSMPass :
public circt::impl::ConvertCoreToFSMBase<CoreToFSMPass> {
1249 using ConvertCoreToFSMBase<CoreToFSMPass>::ConvertCoreToFSMBase;
1251 void runOnOperation()
override {
1252 auto module = getOperation();
1253 OpBuilder builder(module);
1255 SmallVector<HWModuleOp> modules;
1256 for (
auto hwModule : module.getOps<
HWModuleOp>()) {
1257 modules.push_back(hwModule);
1261 for (
auto hwModule : modules) {
1262 for (
auto instance : hwModule.getOps<
hw::InstanceOp>()) {
1263 instance.emitError() <<
"instance conversion is not yet supported";
1264 signalPassFailure();
1269 for (
auto hwModule : modules) {
1270 builder.setInsertionPoint(hwModule);
1271 HWModuleOpConverter converter(builder, hwModule, stateRegs);
1272 if (failed(converter.run())) {
1273 signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
Direction
The direction of a Component or Cell port.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)