18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/IR/Block.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/Diagnostics.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "mlir/Transforms/Passes.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/LogicalResult.h"
39#define GEN_PASS_DEF_CONVERTCORETOFSM
40#include "circt/Conversion/Passes.h.inc"
/// Forward declaration; the definition appears further down in this file.
/// Takes one set of candidate values per operand together with one shift
/// amount per operand, and adds results to `finalPossibleValues`.
52static void generateConcatenatedValues(
53 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
54 const llvm::SmallVector<unsigned> &shifts,
55 llvm::SetVector<size_t> &finalPossibleValues);
/// Recursive worker that accumulates into `possibleValues` the values that `v`
/// may evaluate to. `visited` records the values that have already been
/// examined. The visible cases are hw.constant, mux and concat, followed by a
/// fallback that enumerates every value representable in the integer type.
58static void addPossibleValuesImpl(llvm::SetVector<size_t> &possibleValues,
59 Value v, llvm::DenseSet<Value> &visited) {
// `insert(...).second` is false when `v` was already present in `visited`
// (presumably this bails out early — the guarded statement is not shown).
61 if (!visited.insert(v).second)
// A constant contributes exactly its zero-extended value.
64 if (
auto c = dyn_cast_or_null<hw::ConstantOp>(v.getDefiningOp())) {
65 possibleValues.insert(c.getValueAttr().getValue().getZExtValue());
// A mux can produce whatever either of its two data inputs can produce.
68 if (
auto m = dyn_cast_or_null<MuxOp>(v.getDefiningOp())) {
69 addPossibleValuesImpl(possibleValues, m.getTrueValue(), visited);
70 addPossibleValuesImpl(possibleValues, m.getFalseValue(), visited);
// A concat is handled by collecting the candidate values of every operand and
// then combining them according to each operand's bit position.
74 if (
auto concatOp = dyn_cast_or_null<ConcatOp>(v.getDefiningOp())) {
75 llvm::SmallVector<llvm::SetVector<size_t>> allOperandValues;
76 llvm::SmallVector<unsigned> operandWidths;
78 for (Value operand : concatOp.getOperands()) {
79 llvm::SetVector<size_t> operandPossibleValues;
80 addPossibleValuesImpl(operandPossibleValues, operand, visited);
85 auto opType = dyn_cast<IntegerType>(operand.getType());
88 assert(opType &&
"comb.concat operand must be an integer type");
89 unsigned width = opType.getWidth();
// No candidates were found for this operand: fall back to all 2^width values,
// warning when that would exceed 256 states.
90 if (operandPossibleValues.empty()) {
91 uint64_t numStates = 1ULL << width;
94 if (numStates > 256) {
98 v.getDefiningOp()->emitWarning()
99 <<
"Search space too large (>" << 256
100 <<
" states) for operand with bitwidth " << width
101 <<
"; abandoning analysis for this path";
104 for (uint64_t i = 0; i < numStates; ++i)
105 operandPossibleValues.insert(i);
108 allOperandValues.push_back(operandPossibleValues);
109 operandWidths.push_back(width);
// The last operand gets shift 0; every earlier operand is shifted by the sum
// of the widths of the operands that follow it.
114 llvm::SmallVector<unsigned> shifts(concatOp.getNumOperands(), 0);
115 for (
int i = concatOp.getNumOperands() - 2; i >= 0; --i) {
116 shifts[i] = shifts[i + 1] + operandWidths[i + 1];
119 generateConcatenatedValues(allOperandValues, shifts, possibleValues);
// Fallback: enumerate every value of the integer type of `v`.
// NOTE(review): the warning text below mentions a limit of 16 bits; the
// comparison that guards it is not visible here — confirm.
127 auto addrType = dyn_cast<IntegerType>(v.getType());
131 unsigned bitWidth = addrType.getWidth();
134 if (v.getDefiningOp())
135 v.getDefiningOp()->emitWarning()
136 <<
"Bitwidth " << bitWidth
137 <<
" too large (>16); abandoning analysis for this path";
141 uint64_t numRegStates = 1ULL << bitWidth;
142 for (
size_t i = 0; i < numRegStates; i++) {
143 possibleValues.insert(i);
/// Entry point for the possible-value analysis: creates a fresh `visited` set
/// and delegates to addPossibleValuesImpl, which fills `possibleValues`.
147static void addPossibleValues(llvm::SetVector<size_t> &possibleValues,
149 llvm::DenseSet<Value> visited;
150 addPossibleValuesImpl(possibleValues, v, visited);
/// Decides whether value `a` (associated with the first region) and value `b`
/// (associated with the second region) are structurally the same expression.
/// The visible checks compare whether both defining ops are local to their
/// respective region, compare the result numbers, and finally defer to
/// OperationEquivalence::isEquivalentTo, recursing on operands and ignoring
/// locations.
158static bool areStructurallyEquivalent(Value a, Value b, Region ®ionA,
163 Operation *opA =
a.getDefiningOp();
164 Operation *opB =
b.getDefiningOp();
// "Local" means the defining op lives inside (a descendant of) the region.
168 bool aIsLocal = regionA.isAncestor(opA->getParentRegion());
169 bool bIsLocal = regionB.isAncestor(opB->getParentRegion());
170 if (aIsLocal != bIsLocal)
// Both values must be the same-numbered result of their defining ops.
177 if (cast<OpResult>(a).getResultNumber() !=
178 cast<OpResult>(b).getResultNumber())
// Operand equivalence is checked with this same function, so the comparison
// walks the expression trees recursively.
181 return OperationEquivalence::isEquivalentTo(
183 [&](Value lhs, Value rhs) -> LogicalResult {
184 return success(areStructurallyEquivalent(lhs, rhs, regionA, regionB));
186 nullptr, OperationEquivalence::Flags::IgnoreLocations);
/// Rewrite pattern (matches any op type, benefit 10) that folds i1 results in
/// the action region to constants when they are implied by the facts known
/// from the guard. Each guard fact is a (value, truth value) pair.
192class GuardConditionFoldPattern :
public RewritePattern {
// Copies `guardFacts` into an owned vector and keeps references to both
// regions.
194 GuardConditionFoldPattern(MLIRContext *ctx,
195 ArrayRef<std::pair<Value, bool>> guardFacts,
196 Region &guardRegion, Region &actionRegion)
197 : RewritePattern(MatchAnyOpTypeTag(), 10, ctx),
198 guardFacts(guardFacts.begin(), guardFacts.
end()),
199 guardRegion(guardRegion), actionRegion(actionRegion) {}
// Examines every used i1 result of `op` and compares it against each guard
// fact.
201 LogicalResult matchAndRewrite(Operation *op,
202 PatternRewriter &rewriter)
const override {
// Ops outside the action region and hw.constant ops are tested first
// (presumably rejected — the guarded statements are not shown).
203 if (!actionRegion.isAncestor(op->getParentRegion()))
208 if (isa<hw::ConstantOp>(op))
211 for (Value result : op->getResults()) {
212 if (!result.getType().isInteger(1) || result.use_empty())
215 for (
auto [guardExpr, isTrue] : guardFacts) {
// Same expression as a guard fact: it takes the fact's truth value.
217 if (areStructurallyEquivalent(guardExpr, result, guardRegion,
219 return replaceWithConstant(result, isTrue, op, rewriter);
// An icmp with the negated predicate over equivalent operands takes the
// opposite truth value.
223 if (
auto guardIcmp = guardExpr.getDefiningOp<ICmpOp>()) {
224 if (
auto actionIcmp = dyn_cast<ICmpOp>(op)) {
225 if (actionIcmp.getPredicate() ==
226 ICmpOp::getNegatedPredicate(guardIcmp.getPredicate()) &&
227 areStructurallyEquivalent(guardIcmp.getLhs(),
228 actionIcmp.getLhs(), guardRegion,
230 areStructurallyEquivalent(guardIcmp.getRhs(),
231 actionIcmp.getRhs(), guardRegion,
233 return replaceWithConstant(result, !isTrue, op, rewriter);
// Replaces all uses of `result` with an i1 constant (1 when `constVal` is
// true, 0 otherwise) positioned at the start of the action region, then erases
// `op`.
242 LogicalResult replaceWithConstant(Value result,
bool constVal, Operation *op,
243 PatternRewriter &rewriter)
const {
244 rewriter.setInsertionPointToStart(&actionRegion.front());
246 rewriter.getI1Type(), constVal ? 1 : 0);
247 rewriter.replaceAllUsesWith(result, c);
249 rewriter.eraseOp(op);
// Facts derived from the guard: (expression, known truth value).
253 SmallVector<std::pair<Value, bool>> guardFacts;
255 Region &actionRegion;
/// Simplifies the action region of `transition` using what is known to hold
/// when its guard is taken. Facts are collected from the value returned by the
/// guard, uses of externally defined guard values inside the action region are
/// replaced by constants, and GuardConditionFoldPattern is then applied
/// greedily over the action region.
264static void simplifyActionWithGuard(
TransitionOp transition,
265 OpBuilder &builder) {
266 Region &guardRegion = transition.getGuard();
267 Region &actionRegion = transition.getAction();
269 if (guardRegion.empty() || actionRegion.empty())
273 dyn_cast<fsm::ReturnOp>(guardRegion.front().getTerminator());
277 Location loc = guardReturn.getLoc();
278 Value guardCondition = guardReturn.getOperand();
281 SmallVector<std::pair<Value, bool>> guardFacts;
// Worklist of conditions known to hold, seeded with the guard condition
// itself being true.
284 SmallVector<std::pair<Value, bool>> worklist;
285 worklist.push_back({guardCondition,
true});
287 while (!worklist.empty()) {
288 auto [cond, isTrue] = worklist.pop_back_val();
289 guardFacts.push_back({cond, isTrue});
// A binary-not inside the guard region also fixes the value of its operand
// to the opposite truth value.
292 if (
auto xorOp = cond.getDefiningOp<
XorOp>()) {
293 if (xorOp.isBinaryNot() &&
294 guardRegion.isAncestor(xorOp->getParentRegion()))
295 guardFacts.push_back({xorOp.getOperand(0), !isTrue});
// Every operand of an AND inside the guard region is queued as true.
300 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
301 if (guardRegion.isAncestor(andOp->getParentRegion())) {
302 for (Value operand : andOp.getOperands())
303 worklist.push_back({operand,
true});
// Facts about values defined outside the guard region (no defining op, or a
// defining op elsewhere) are applied directly by replacing selected uses
// inside the action region with a constant.
311 for (
auto [guardExpr, isTrue] : guardFacts) {
312 bool guardIsExternal =
313 !guardExpr.getDefiningOp() ||
314 !guardRegion.isAncestor(guardExpr.getDefiningOp()->getParentRegion());
315 if (guardIsExternal) {
316 builder.setInsertionPointToStart(&actionRegion.front());
// The predicate inspects whether the use is in the action region and whether
// it is operand 0 of an UpdateOp (the results of those tests are not shown).
319 guardExpr.replaceUsesWithIf(constOp.getResult(), [&](OpOperand &use) {
320 if (!actionRegion.isAncestor(use.getOwner()->getParentRegion()))
324 if (auto updateOp = dyn_cast<UpdateOp>(use.getOwner()))
325 if (use.getOperandNumber() == 0)
// Fold the remaining, structurally matched facts with the pattern driver,
// restricted to the action region. The driver's result is deliberately
// ignored.
333 MLIRContext *ctx = transition.getContext();
335 patterns.add<GuardConditionFoldPattern>(ctx, guardFacts, guardRegion,
337 SmallVector<Operation *> actionOps;
338 actionRegion.walk([&](Operation *op) { actionOps.push_back(op); });
339 GreedyRewriteConfig config;
340 config.setScope(&actionRegion);
341 (void)applyOpPatternsGreedily(
342 actionOps, FrozenRewritePatternSet(std::move(
patterns)), config);
/// Walks the expression rooted at `value` with an explicit worklist, looking
/// through mux ops into both of their data inputs and recognizing hw.constant
/// leaves. NOTE(review): presumably returns true only when every leaf is a
/// constant — the return statements are not visible here; confirm.
347static bool isConstantOrConstantTree(Value value) {
348 SmallVector<Value> worklist;
349 llvm::DenseSet<Value> visited;
351 worklist.push_back(value);
352 while (!worklist.empty()) {
353 Value current = worklist.pop_back_val();
// `visited` records every value popped so far.
356 if (!visited.insert(current).second)
359 Operation *definingOp = current.getDefiningOp();
363 if (isa<hw::ConstantOp>(definingOp))
// For a mux, both data inputs still have to be examined.
366 if (
auto muxOp = dyn_cast<MuxOp>(definingOp)) {
367 worklist.push_back(muxOp.getTrueValue());
368 worklist.push_back(muxOp.getFalseValue());
/// Pushes an equality comparison through a mux. When one side of an `eq` icmp
/// is defined by a mux (and the additional condition holds), two new `eq`
/// comparisons are created — one per mux data input, each against the other
/// icmp operand — and the icmp is replaced by a mux on the original condition.
/// Returns success when a rewrite happened and failure otherwise.
383LogicalResult pushIcmp(ICmpOp op, PatternRewriter &rewriter) {
// Case 1: the mux is on the left-hand side.
385 if (op.getPredicate() == ICmpPredicate::eq &&
386 op.getLhs().getDefiningOp<
MuxOp>() &&
387 (isConstantOrConstantTree(op.getLhs()) ||
389 rewriter.setInsertionPointAfter(op);
390 auto mux = op.getLhs().getDefiningOp<
MuxOp>();
391 Value x = mux.getTrueValue();
392 Value y = mux.getFalseValue();
393 Value
b = op.getRhs();
394 Location loc = op.getLoc();
395 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
396 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
397 rewriter.replaceOpWithNewOp<
MuxOp>(op, mux.getCond(), eq1.getResult(),
399 return llvm::success();
// Case 2: the mux is on the right-hand side; same rewrite with roles swapped.
401 if (op.getPredicate() == ICmpPredicate::eq &&
402 op.getRhs().getDefiningOp<
MuxOp>() &&
403 (isConstantOrConstantTree(op.getRhs()) ||
405 rewriter.setInsertionPointAfter(op);
406 auto mux = op.getRhs().getDefiningOp<
MuxOp>();
407 Value x = mux.getTrueValue();
408 Value y = mux.getFalseValue();
409 Value
b = op.getLhs();
410 Location loc = op.getLoc();
411 auto eq1 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, x, b);
412 auto eq2 = ICmpOp::create(rewriter, loc, ICmpPredicate::eq, y, b);
413 rewriter.replaceOpWithNewOp<
MuxOp>(op, mux.getCond(), eq1.getResult(),
415 return llvm::success();
// Neither side matched.
417 return llvm::failure();
/// Combines per-operand candidate values into the candidate values of the
/// whole concatenation. Every operand value is shifted left by that operand's
/// entry in `shifts` and OR-ed with every partial result built from the
/// preceding operands; the combinations are added to `finalPossibleValues`.
/// With no operands the single value 0 is added.
422static void generateConcatenatedValues(
423 const llvm::SmallVector<llvm::SetVector<size_t>> &allOperandValues,
424 const llvm::SmallVector<unsigned> &shifts,
425 llvm::SetVector<size_t> &finalPossibleValues) {
427 if (allOperandValues.empty()) {
428 finalPossibleValues.insert(0);
// Seed the partial results with the (shifted) values of the first operand.
433 llvm::SetVector<size_t> currentResults;
434 for (
size_t val : allOperandValues[0])
435 currentResults.insert(val << shifts[0]);
// Fold in the remaining operands one at a time.
438 for (
size_t operandIdx = 1; operandIdx < allOperandValues.size();
440 llvm::SetVector<size_t> nextResults;
441 unsigned shift = shifts[operandIdx];
443 for (
size_t partialValue : currentResults) {
444 for (
size_t val : allOperandValues[operandIdx]) {
445 nextResults.insert(partialValue | (val << shift));
448 currentResults = std::move(nextResults);
// Publish the fully combined values.
451 for (
size_t val : currentResults)
452 finalPossibleValues.insert(val);
/// Builds a map from register values to integer values by iterating over the
/// registers in `v` and masking out the low `bits` bits of an integer `i` for
/// each one, where `bits` is the bit width of the register's type.
/// NOTE(review): the local `int v` below shadows the parameter `v`, and
/// `1 << bits` is an `int` shift — confirm the supported register widths.
455static llvm::MapVector<Value, int> intToRegMap(SmallVector<seq::CompRegOp> v,
457 llvm::MapVector<Value, int> m;
458 for (
size_t ci = 0; ci < v.size(); ci++) {
460 int bits =
reg.getType().getIntOrFloatBitWidth();
// Keep only the low `bits` bits.
461 int v = i & ((1 << bits) - 1);
/// Packs the per-register values in `m` into a single integer: iterating over
/// the registers in `v`, each register's value is shifted left by the total
/// width of the registers processed before it and accumulated.
468static int regMapToInt(SmallVector<seq::CompRegOp> v,
469 llvm::DenseMap<Value, int> m) {
472 for (
size_t ci = 0; ci < v.size(); ci++) {
// `*` binds tighter than `<<`, so this is (m[reg] * 1ULL) << width.
474 i += m[
reg] * 1ULL << width;
// Advance the bit offset by this register's width.
475 width += (
reg.getType().getIntOrFloatBitWidth());
/// Computes the Cartesian product of `valueSets`: the result holds vectors
/// with one element drawn from each input set, in input order. The product is
/// grown set by set, appending every element of the next set to every vector
/// built so far.
481static std::set<llvm::SmallVector<size_t>> calculateCartesianProduct(
482 const llvm::SmallVector<llvm::SetVector<size_t>> &valueSets) {
483 std::set<llvm::SmallVector<size_t>> product;
// Special handling when there are no input sets (body not shown here).
484 if (valueSets.empty()) {
// Start with one single-element vector per value of the first set.
493 for (
size_t value : valueSets.front()) {
494 product.insert({value});
500 for (
size_t i = 1; i < valueSets.size(); ++i) {
501 const auto ¤tSet = valueSets[i];
// Special handling for an empty input set (body not shown here).
502 if (currentSet.empty()) {
// Extend every existing vector with every value of the current set.
507 std::set<llvm::SmallVector<size_t>> newProduct;
508 for (
const auto &existingVector : product) {
509 for (
size_t newValue : currentSet) {
510 llvm::SmallVector<size_t> newVector = existingVector;
511 newVector.push_back(newValue);
512 newProduct.insert(std::move(newVector));
515 product = std::move(newProduct);
/// Builds the frozen pattern set used for simplification: collects the
/// canonicalization patterns of every dialect loaded in `context` and freezes
/// the resulting set. (Further patterns may be added on lines not shown here.)
521static FrozenRewritePatternSet loadPatterns(MLIRContext &
context) {
524 for (
auto *dialect :
context.getLoadedDialects())
525 dialect->getCanonicalizationPatterns(
patterns);
540 FrozenRewritePatternSet frozenPatterns(std::move(
patterns));
541 return frozenPatterns;
/// Computes the state indices that can follow `currentStateIndex` and inserts
/// them into `visitableStates`. Works on a clone of `moduleOp`: each register
/// is replaced by the constant it holds in the current state, the module
/// output is replaced by the register inputs (the next-state values), the
/// clone is simplified with the greedy pattern driver, and the possible values
/// of each next-state expression are enumerated and combined.
545getReachableStates(llvm::SetVector<size_t> &visitableStates,
546 HWModuleOp moduleOp,
size_t currentStateIndex,
547 SmallVector<seq::CompRegOp> registers, OpBuilder opBuilder,
548 bool isInitialState) {
552 llvm::dyn_cast<HWModuleOp>(opBuilder.clone(*moduleOp, mapping));
// Decode the current state index into one value per register.
554 llvm::MapVector<Value, int> stateMap =
555 intToRegMap(registers, currentStateIndex);
556 Operation *terminator = clonedBody.getBody().front().getTerminator();
557 auto output = dyn_cast<hw::OutputOp>(terminator);
558 SmallVector<Value> values;
560 for (
auto [originalRegValue, constStateValue] : stateMap) {
// Find the cloned counterpart of the register via the clone mapping.
562 Value clonedRegValue = mapping.lookup(originalRegValue);
563 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
564 auto reg = cast<seq::CompRegOp>(clonedRegOp);
565 Type constantType =
reg.getType();
566 IntegerAttr constantAttr =
567 opBuilder.getIntegerAttr(constantType, constStateValue);
568 opBuilder.setInsertionPoint(clonedRegOp);
569 auto otherStateConstant =
// A register that feeds itself keeps its current value; otherwise its input
// is the next-state expression.
574 Value regInput =
reg.getInput();
575 if (regInput == clonedRegValue)
576 values.push_back(otherStateConstant.getResult());
578 values.push_back(regInput);
579 clonedRegValue.replaceAllUsesWith(otherStateConstant.getResult());
// New output op whose operands are the next-state expressions.
582 opBuilder.setInsertionPointToEnd(clonedBody.front().getBlock());
583 auto newOutput = hw::OutputOp::create(opBuilder, output.getLoc(), values);
585 FrozenRewritePatternSet frozenPatterns = loadPatterns(*moduleOp.getContext());
587 SmallVector<Operation *> opsToProcess;
588 clonedBody.walk([&](Operation *op) { opsToProcess.push_back(op); });
590 bool changed =
false;
591 GreedyRewriteConfig config;
592 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns, config,
// Enumerate the candidate values of every next-state expression.
596 llvm::SmallVector<llvm::SetVector<size_t>> pv;
597 for (
size_t j = 0; j < newOutput.getNumOperands(); j++) {
598 llvm::SetVector<size_t> possibleValues;
600 Value v = newOutput.getOperand(j);
601 addPossibleValues(possibleValues, v);
602 pv.push_back(possibleValues);
// Each combination of per-register values is re-encoded as a state index.
604 std::set<llvm::SmallVector<size_t>> flipped = calculateCartesianProduct(pv);
605 for (llvm::SmallVector<size_t> v : flipped) {
606 llvm::DenseMap<Value, int> m;
607 for (
size_t k = 0; k < v.size(); k++) {
612 int i = regMapToInt(registers, m);
613 visitableStates.insert(i);
/// Converts a single hw.module into an fsm.machine. State registers are turned
/// into explicit FSM states, the remaining registers become FSM variables, and
/// the module body is cloned into the output, guard and action regions of the
/// generated states and transitions.
621class HWModuleOpConverter {
// `stateRegNames` optionally lists the names of the registers to be treated as
// state registers (see the name check near the end of this class).
623 HWModuleOpConverter(OpBuilder &builder,
HWModuleOp moduleOp,
624 ArrayRef<std::string> stateRegNames)
625 : moduleOp(moduleOp), opBuilder(builder), stateRegNames(stateRegNames) {}
// Performs the conversion.
626 LogicalResult
run() {
627 SmallVector<seq::CompRegOp> stateRegs;
628 SmallVector<seq::CompRegOp> variableRegs;
629 Value foundClock, foundReset =
nullptr;
// Register validation: all registers must share one clock and one reset and
// must be integer-typed. Each register is classified as state or variable.
632 auto reset =
reg.getReset();
634 if (clk != foundClock) {
635 reg.emitError(
"All registers must have the same clock signal.");
636 return WalkResult::interrupt();
644 if (reset != foundReset) {
645 reg.emitError(
"All registers must have the same reset signal.");
646 return WalkResult::interrupt();
654 if (!isa<IntegerType>(
reg.getType())) {
655 reg.emitError(
"FSM extraction only supports integer-typed registers");
656 return WalkResult::interrupt();
658 if (isStateRegister(reg)) {
659 stateRegs.push_back(reg);
661 variableRegs.push_back(reg);
663 return WalkResult::advance();
665 if (walkResult.wasInterrupted())
// At least one state register is required.
667 if (stateRegs.empty()) {
668 emitError(moduleOp.getLoc())
669 <<
"Cannot find state register in this FSM. Use the state-regs "
670 "option to specify which registers are state registers.";
673 SmallVector<seq::CompRegOp> registers;
675 registers.push_back(c);
// Bidirectional lookup between encoded state indices and their StateOps.
678 llvm::DenseMap<size_t, StateOp> stateToStateOp;
679 llvm::DenseMap<StateOp, size_t> stateOpToState;
// Argument numbers of block arguments that were recognized as reset signals.
684 llvm::DenseSet<size_t> asyncResetArguments;
685 Location loc = moduleOp.getLoc();
686 SmallVector<Type> inputTypes = moduleOp.getInputTypes();
689 auto resultTypes = moduleOp.getOutputTypes();
690 FunctionType machineType =
691 FunctionType::get(opBuilder.getContext(), inputTypes, resultTypes);
692 StringRef machineName = moduleOp.getName();
// The initial state is derived from the registers' reset values.
694 llvm::DenseMap<Value, int> initialStateMap;
696 Value resetValue =
reg.getResetValue();
702 reg.emitWarning(
"Assuming register with no reset starts with value 0");
706 if (!definingConstant) {
708 "cannot find defining constant for reset value of register");
712 definingConstant.getValueAttr().getValue().getZExtValue();
713 initialStateMap[
reg] = resetValueInt;
715 int initialStateIndex = regMapToInt(registers, initialStateMap);
717 std::string initialStateName =
"state_" + std::to_string(initialStateIndex);
// Carry over the port names; the module's "resultNames" attribute is stored
// on the machine under the name "resNames".
720 SmallVector<NamedAttribute> machineAttrs;
721 if (
auto argNames = moduleOp->getAttrOfType<ArrayAttr>(
"argNames"))
722 machineAttrs.emplace_back(opBuilder.getStringAttr(
"argNames"), argNames);
723 if (
auto resNames = moduleOp->getAttrOfType<ArrayAttr>(
"resultNames"))
724 machineAttrs.emplace_back(opBuilder.getStringAttr(
"resNames"), resNames);
// The machine is created immediately before the module being converted.
728 opBuilder.setInsertionPoint(moduleOp);
730 MachineOp::create(opBuilder, loc, machineName, initialStateName,
731 machineType, machineAttrs);
733 OpBuilder::InsertionGuard guard(opBuilder);
734 opBuilder.setInsertionPointToStart(&machine.getBody().front());
// One fsm variable per variable register, initialized from its reset value.
735 llvm::MapVector<seq::CompRegOp, VariableOp> variableMap;
737 TypedValue<Type> initialValue = varReg.getResetValue();
744 "Assuming register with no reset starts with value 0");
746 varReg.getType(), 0);
748 if (!definingConstant) {
749 varReg->emitError(
"cannot find defining constant for reset value of "
750 "variable register");
753 auto variableOp = VariableOp::create(
754 opBuilder, varReg->getLoc(), varReg.getInput().getType(),
755 definingConstant.getValueAttr(), varReg.getName().value_or(
"var"));
756 variableMap[varReg] = variableOp;
760 FrozenRewritePatternSet frozenPatterns =
761 loadPatterns(*moduleOp.getContext());
// Worklist-driven exploration of the state space, starting from the initial
// state. `worklist` grows while it is being iterated by index.
763 SetVector<int> reachableStates;
764 SmallVector<int> worklist;
766 worklist.push_back(initialStateIndex);
767 reachableStates.insert(initialStateIndex);
770 for (
unsigned i = 0; i < worklist.size(); ++i) {
772 int currentStateIndex = worklist[i];
774 llvm::MapVector<Value, int> stateMap =
775 intToRegMap(registers, currentStateIndex);
777 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
// Create the StateOp for this index on first use, otherwise reuse it.
781 if (!stateToStateOp.contains(currentStateIndex)) {
783 "state_" + std::to_string(currentStateIndex));
784 stateToStateOp.insert({currentStateIndex, stateOp});
785 stateOpToState.insert({stateOp, currentStateIndex});
787 stateOp = stateToStateOp.lookup(currentStateIndex);
// Output region: clone the module body in front of the placeholder block and
// then drop the placeholder.
789 Region &outputRegion = stateOp.getOutput();
790 Block *outputBlock = &outputRegion.front();
791 opBuilder.setInsertionPointToStart(outputBlock);
793 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), outputRegion,
794 outputBlock->getIterator(), mapping);
795 outputBlock->erase();
797 auto *terminator = outputRegion.front().getTerminator();
798 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
799 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
803 OpBuilder::InsertionGuard stateGuard(opBuilder);
804 opBuilder.setInsertionPoint(hwOutputOp);
809 hwOutputOp.getOperands());
// Cloned variable registers are replaced by the corresponding fsm variable.
816 for (
auto &[originalRegValue, variableOp] : variableMap) {
817 Value clonedRegValue = mapping.lookup(originalRegValue);
818 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
819 auto reg = cast<seq::CompRegOp>(clonedRegOp);
820 const auto res = variableOp.getResult();
821 clonedRegValue.replaceAllUsesWith(res);
// Cloned state registers are replaced by the constant they hold in the
// current state. A reset that is a block argument, or the binary-not of one,
// is recorded and replaced by a constant.
824 for (
auto const &[originalRegValue, constStateValue] : stateMap) {
826 Value clonedRegValue = mapping.lookup(originalRegValue);
828 "Original register value not found in mapping");
829 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
831 assert(clonedRegOp &&
"Cloned value must have a defining op");
832 opBuilder.setInsertionPoint(clonedRegOp);
833 auto r = cast<seq::CompRegOp>(clonedRegOp);
834 TypedValue<IntegerType> registerReset = r.getReset();
836 if (BlockArgument blockArg = dyn_cast<BlockArgument>(registerReset)) {
837 asyncResetArguments.insert(blockArg.getArgNumber());
839 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
840 blockArg.replaceAllUsesWith(falseConst.getResult());
842 if (
auto xorOp = registerReset.getDefiningOp<
XorOp>()) {
843 if (xorOp.isBinaryNot()) {
844 Value rhs = xorOp.getOperand(0);
845 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
846 asyncResetArguments.insert(blockArg.getArgNumber());
848 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
849 blockArg.replaceAllUsesWith(trueConst.getResult());
856 clonedRegValue.getType(), constStateValue);
857 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
858 clonedRegOp->erase();
860 GreedyRewriteConfig config;
861 SmallVector<Operation *> opsToProcess;
862 outputRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
// Rewire the cloned block's arguments to the machine's own arguments, then
// drop the cloned arguments.
865 for (
auto arg : outputRegion.front().getArguments()) {
866 int argIndex = arg.getArgNumber();
867 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
868 arg.replaceAllUsesWith(topLevelArg);
870 outputRegion.front().eraseArguments(
871 [](BlockArgument arg) {
return true; });
872 FrozenRewritePatternSet
patterns(opBuilder.getContext());
873 config.setScope(&outputRegion);
875 bool changed =
false;
876 if (failed(applyOpPatternsGreedily(opsToProcess,
patterns, config,
879 opBuilder.setInsertionPoint(stateOp);
// The output block is sorted topologically; the diagnostic below is the one
// used for modules containing combinational cycles.
884 bool sorted = sortTopologically(&outputRegion.front());
887 <<
"cannot convert module with combinational cycles to FSM";
// Transitions: one per state that may follow the current state.
890 Region &transitionRegion = stateOp.getTransitions();
891 llvm::SetVector<size_t> visitableStates;
892 if (failed(getReachableStates(visitableStates, moduleOp,
893 currentStateIndex, registers, opBuilder,
894 currentStateIndex == initialStateIndex)))
896 for (
size_t j : visitableStates) {
// Create the target StateOp on first use, otherwise reuse it.
898 if (!stateToStateOp.contains(j)) {
899 opBuilder.setInsertionPointToEnd(&machine.getBody().front());
902 stateToStateOp.insert({j, toState});
903 stateOpToState.insert({toState, j});
905 toState = stateToStateOp[j];
907 opBuilder.setInsertionPointToStart(&transitionRegion.front());
// Guard region: a clone of the module body that returns whether every state
// register's input equals its value in target state `j`.
910 Region &guardRegion = transitionOp.getGuard();
911 opBuilder.createBlock(&guardRegion);
913 Block &guardBlock = guardRegion.front();
915 opBuilder.setInsertionPointToStart(&guardBlock);
917 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), guardRegion,
918 guardBlock.getIterator(), mapping);
920 Block &newGuardBlock = guardRegion.front();
921 Operation *terminator = newGuardBlock.getTerminator();
922 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
923 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
925 llvm::MapVector<Value, int> toStateMap = intToRegMap(registers, j);
926 SmallVector<Value> equalityChecks;
927 for (
auto &[originalRegValue, variableOp] : variableMap) {
928 opBuilder.setInsertionPointToStart(&newGuardBlock);
929 Value clonedRegValue = mapping.lookup(originalRegValue);
930 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
931 auto reg = cast<seq::CompRegOp>(clonedRegOp);
932 const auto res = variableOp.getResult();
933 clonedRegValue.replaceAllUsesWith(res);
// One `eq` comparison per state register: register input vs. target value.
936 for (
auto const &[originalRegValue, constStateValue] : toStateMap) {
938 Value clonedRegValue = mapping.lookup(originalRegValue);
939 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
940 opBuilder.setInsertionPoint(clonedRegOp);
941 auto r = cast<seq::CompRegOp>(clonedRegOp);
943 Value registerInput = r.getInput();
944 TypedValue<IntegerType> registerReset = r.getReset();
946 if (BlockArgument blockArg =
947 dyn_cast<BlockArgument>(registerReset)) {
949 opBuilder, blockArg.getLoc(), clonedRegValue.getType(), 0);
950 blockArg.replaceAllUsesWith(falseConst.getResult());
952 if (
auto xorOp = registerReset.getDefiningOp<
XorOp>()) {
953 if (xorOp.isBinaryNot()) {
954 Value rhs = xorOp.getOperand(0);
955 if (BlockArgument blockArg = dyn_cast<BlockArgument>(rhs)) {
957 opBuilder, blockArg.getLoc(), blockArg.getType(), 1);
958 blockArg.replaceAllUsesWith(trueConst.getResult());
963 Type constantType = registerInput.getType();
964 IntegerAttr constantAttr =
965 opBuilder.getIntegerAttr(constantType, constStateValue);
967 opBuilder, hwOutputOp.getLoc(), constantAttr);
970 ICmpOp::create(opBuilder, hwOutputOp.getLoc(), ICmpPredicate::eq,
971 registerInput, otherStateConstant.getResult());
972 equalityChecks.push_back(doesEqual.getResult());
// The guard returns the AND of all per-register equality checks.
974 opBuilder.setInsertionPoint(hwOutputOp);
975 auto allEqualCheck = AndOp::create(opBuilder, hwOutputOp.getLoc(),
976 equalityChecks,
false);
977 fsm::ReturnOp::create(opBuilder, hwOutputOp.getLoc(),
978 allEqualCheck.getResult());
980 for (BlockArgument arg : newGuardBlock.getArguments()) {
981 int argIndex = arg.getArgNumber();
982 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
983 arg.replaceAllUsesWith(topLevelArg);
985 newGuardBlock.eraseArguments([](BlockArgument arg) {
return true; });
// Inside the guard, the state registers read as the constants of the state the
// transition starts from.
986 llvm::MapVector<Value, int> fromStateMap =
987 intToRegMap(registers, currentStateIndex);
988 for (
auto const &[originalRegValue, constStateValue] : fromStateMap) {
989 Value clonedRegValue = mapping.lookup(originalRegValue);
991 "Original register value not found in mapping");
992 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
993 assert(clonedRegOp &&
"Cloned value must have a defining op");
994 opBuilder.setInsertionPoint(clonedRegOp);
997 clonedRegValue.getType(), constStateValue);
998 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
999 clonedRegOp->erase();
// Action region: only populated when the module has variable registers. It
// holds a clone of the module body with one fsm update per variable.
1001 Region &actionRegion = transitionOp.getAction();
1002 if (!variableRegs.empty()) {
1003 Block *actionBlock = opBuilder.createBlock(&actionRegion);
1004 opBuilder.setInsertionPointToStart(actionBlock);
1006 opBuilder.cloneRegionBefore(moduleOp.getModuleBody(), actionRegion,
1007 actionBlock->getIterator(), mapping);
1008 actionBlock->erase();
1009 Block &newActionBlock = actionRegion.front();
1010 for (BlockArgument arg : newActionBlock.getArguments()) {
1011 int argIndex = arg.getArgNumber();
1012 BlockArgument topLevelArg = machine.getBody().getArgument(argIndex);
1013 arg.replaceAllUsesWith(topLevelArg);
1015 newActionBlock.eraseArguments([](BlockArgument arg) {
return true; });
1016 for (
auto &[originalRegValue, variableOp] : variableMap) {
1017 Value clonedRegValue = mapping.lookup(originalRegValue);
1018 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1019 auto reg = cast<seq::CompRegOp>(clonedRegOp);
1020 opBuilder.setInsertionPointToStart(&newActionBlock);
1021 UpdateOp::create(opBuilder,
reg.getLoc(), variableOp,
1023 const Value res = variableOp.getResult();
1024 clonedRegValue.replaceAllUsesWith(res);
1027 Operation *terminator = actionRegion.back().getTerminator();
1028 auto hwOutputOp = dyn_cast<hw::OutputOp>(terminator);
1029 assert(hwOutputOp &&
"Expected terminator to be hw.output op");
1032 for (
auto const &[originalRegValue, constStateValue] : fromStateMap) {
1033 Value clonedRegValue = mapping.lookup(originalRegValue);
1034 Operation *clonedRegOp = clonedRegValue.getDefiningOp();
1035 opBuilder.setInsertionPoint(clonedRegOp);
1037 opBuilder, clonedRegValue.getLoc(), clonedRegValue.getType(),
1039 clonedRegValue.replaceAllUsesWith(constantOp.getResult());
1040 clonedRegOp->erase();
// Simplify the action region with the shared pattern set.
1043 GreedyRewriteConfig config;
1044 SmallVector<Operation *> opsToProcess;
1045 actionRegion.walk([&](Operation *op) { opsToProcess.push_back(op); });
1046 config.setScope(&actionRegion);
1048 bool changed =
false;
1049 if (failed(applyOpPatternsGreedily(opsToProcess, frozenPatterns,
1056 bool actionSorted = sortTopologically(&actionRegion.front());
1057 if (!actionSorted) {
1058 moduleOp.emitError()
1059 <<
"cannot convert module with combinational cycles to FSM";
1067 bool guardSorted = sortTopologically(&newGuardBlock);
1069 moduleOp.emitError()
1070 <<
"cannot convert module with combinational cycles to FSM";
// Canonicalize the state's output region, then its transitions.
1073 SmallVector<Operation *> outputOps;
1074 stateOp.getOutput().walk(
1075 [&](Operation *op) { outputOps.push_back(op); });
1077 bool changed =
false;
1078 GreedyRewriteConfig config;
1079 config.setScope(&stateOp.getOutput());
1080 LogicalResult converged = applyOpPatternsGreedily(
1081 outputOps, frozenPatterns, config, &changed);
1082 assert(succeeded(converged) &&
"canonicalization failed to converge");
1083 SmallVector<Operation *> transitionOps;
1084 stateOp.getTransitions().walk(
1085 [&](Operation *op) { transitionOps.push_back(op); });
1087 GreedyRewriteConfig config2;
1088 config2.setScope(&stateOp.getTransitions());
1089 if (failed(applyOpPatternsGreedily(transitionOps, frozenPatterns,
1090 config2, &changed))) {
// Use the guard's facts to simplify the action, then canonicalize the
// transitions again.
1099 simplifyActionWithGuard(transition, opBuilder);
1103 SmallVector<Operation *> postOps;
1104 stateOp.getTransitions().walk(
1105 [&](Operation *op) { postOps.push_back(op); });
1106 GreedyRewriteConfig postConfig;
1107 postConfig.setScope(&stateOp.getTransitions());
1108 if (failed(applyOpPatternsGreedily(postOps, frozenPatterns,
1109 postConfig, &changed)))
// A target state is treated as reachable unless its guard folded to a constant
// zero; newly reachable states are appended to the worklist.
1115 StateOp nextState = transition.getNextStateOp();
1116 int nextStateIndex = stateOpToState.lookup(nextState);
1117 auto guardConst = transition.getGuardReturn()
1120 bool nextStateIsReachable =
1121 !guardConst || (guardConst.getValueAttr().getInt() != 0);
1124 if (nextStateIsReachable &&
1125 !reachableStates.contains(nextStateIndex)) {
1126 worklist.push_back(nextStateIndex);
1127 reachableStates.insert(nextStateIndex);
// Cleanup: states that never received an output op are collected for erasure,
// and transitions are matched against them by symbol name.
1136 SmallVector<StateOp> statesToErase;
1140 if (!stateOp.getOutputOp()) {
1141 statesToErase.push_back(stateOp);
1149 for (
StateOp stateOp : statesToErase) {
1151 if (transition.getNextStateOp().getSymName() == stateOp.getSymName()) {
// Remove the recorded reset arguments from the machine's entry block.
1158 llvm::DenseSet<BlockArgument> asyncResetBlockArguments;
1159 for (
auto arg : machine.getBody().front().getArguments()) {
1160 if (asyncResetArguments.contains(arg.getArgNumber())) {
1161 asyncResetBlockArguments.insert(arg);
1169 if (!asyncResetBlockArguments.empty()) {
1170 moduleOp.emitWarning()
1171 <<
"reset signals detected and removed from FSM; "
1172 "reset behavior is captured only in the initial state";
1175 Block &front = machine.getBody().front();
1176 front.eraseArguments([&](BlockArgument arg) {
1177 return asyncResetBlockArguments.contains(arg);
// A clock argument that still has uses at this point is reported as an error.
1180 if (llvm::any_of(front.getArguments(), [](BlockArgument arg) {
1181 return arg.getType() == seq::ClockType::get(arg.getContext()) &&
1182 arg.hasNUsesOrMore(1);
1184 moduleOp.emitError(
"Clock uses outside register clocking are not "
1185 "currently supported.");
1188 machine.getBody().front().eraseArguments([&](BlockArgument arg) {
1189 return arg.getType() == seq::ClockType::get(arg.getContext());
// Rebuild the machine's function type without the clock and reset inputs.
1191 FunctionType oldFunctionType = machine.getFunctionType();
1192 SmallVector<Type> inputsWithoutClock;
1193 for (
unsigned int i = 0; i < oldFunctionType.getNumInputs(); i++) {
1194 Type input = oldFunctionType.getInput(i);
1195 if (input != seq::ClockType::get(input.getContext()) &&
1196 !asyncResetArguments.contains(i))
1197 inputsWithoutClock.push_back(input);
1200 FunctionType newFunctionType = FunctionType::get(
1201 opBuilder.getContext(), inputsWithoutClock, resultTypes);
1203 machine.setFunctionType(newFunctionType);
// State-register classification by name: when an explicit list of names was
// supplied, membership in that list decides; otherwise a register whose name
// contains "state" qualifies.
1211 auto regName =
reg.getName();
1217 if (!stateRegNames.empty()) {
1218 return llvm::is_contained(stateRegNames, regName->str());
1223 return regName->contains(
"state");
1227 OpBuilder &opBuilder;
1228 ArrayRef<std::string> stateRegNames;
/// Pass driver: converts every hw.module directly contained in the operation
/// the pass runs on by handing it to HWModuleOpConverter.
1234struct CoreToFSMPass :
public circt::impl::ConvertCoreToFSMBase<CoreToFSMPass> {
1235 using ConvertCoreToFSMBase<CoreToFSMPass>::ConvertCoreToFSMBase;
1237 void runOnOperation()
override {
1238 auto module = getOperation();
1239 OpBuilder builder(module);
// Snapshot the modules first; the conversion loop below iterates this copy
// rather than the live op list.
1241 SmallVector<HWModuleOp> modules;
1242 for (
auto hwModule : module.getOps<
HWModuleOp>()) {
1243 modules.push_back(hwModule);
// Any hw.instance inside a module is reported as unsupported and fails the
// pass.
1247 for (
auto hwModule : modules) {
1248 for (
auto instance : hwModule.getOps<
hw::InstanceOp>()) {
1249 instance.emitError() <<
"instance conversion is not yet supported";
1250 signalPassFailure();
// Convert each module; `stateRegs` is passed as the converter's list of state
// register names. A failed conversion fails the pass.
1255 for (
auto hwModule : modules) {
1256 builder.setInsertionPoint(hwModule);
1257 HWModuleOpConverter converter(builder, hwModule, stateRegs);
1258 if (failed(converter.run())) {
1259 signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)