17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/RegionUtils.h"
19#include "llvm/ADT/TypeSwitch.h"
25#define GEN_PASS_DEF_CONVERTFSMTOSV
26#include "circt/Conversion/Passes.h.inc"
43 machine.getHWPortInfo(ports);
44 ClkRstIdxs specialPorts;
48 clock.
name = b.getStringAttr(
"clk");
49 clock.
dir = hw::ModulePort::Direction::Input;
50 clock.
type = seq::ClockType::get(b.getContext());
51 clock.
argNum = machine.getNumArguments();
52 ports.push_back(clock);
53 specialPorts.clockIdx = clock.
argNum;
57 reset.
name = b.getStringAttr(
"rst");
58 reset.
dir = hw::ModulePort::Direction::Input;
59 reset.
type = b.getI1Type();
60 reset.
argNum = machine.getNumArguments() + 1;
61 ports.push_back(reset);
62 specialPorts.resetIdx = reset.
argNum;
70 llvm::SetVector<Value> captures;
71 getUsedValuesDefinedAbove(region, region, captures);
73 OpBuilder::InsertionGuard guard(builder);
74 builder.setInsertionPointToStart(®ion.front());
77 for (
auto &capture : captures) {
78 Operation *op = capture.getDefiningOp();
79 if (!op || !op->hasTrait<OpTrait::ConstantLike>())
82 Operation *cloned = builder.clone(*op);
83 for (
auto [orig, replacement] :
84 llvm::zip(op->getResults(), cloned->getResults()))
85 replaceAllUsesInRegionWith(orig, replacement, region);
106 Type getStateType() {
return stateType; }
109 std::unique_ptr<sv::CasePattern> getCasePattern(
StateOp state);
118 void setEncoding(
StateOp state, Value v,
bool wire =
false);
142 : typeScope(typeScope), b(b), machine(machine), hwModule(hwModule) {
143 Location loc = machine.getLoc();
144 llvm::SmallVector<Attribute> stateNames;
146 for (
auto state : machine.getBody().getOps<
StateOp>())
147 stateNames.push_back(b.getStringAttr(state.
getName()));
151 hw::EnumType::get(b.getContext(), b.getArrayAttr(stateNames));
153 OpBuilder::InsertionGuard guard(b);
154 b.setInsertionPointToStart(&typeScope.getBodyRegion().front());
156 b, loc, b.getStringAttr(hwModule.getName() +
"_state_t"),
157 TypeAttr::get(rawEnumType),
nullptr);
159 stateType = hw::TypeAliasType::get(
160 SymbolRefAttr::get(typeScope.getSymNameAttr(),
161 {FlatSymbolRefAttr::get(typedeclEnumType)}),
165 b.setInsertionPointToStart(&hwModule.getBody().front());
166 for (
auto state : machine.getBody().getOps<
StateOp>()) {
167 auto fieldAttr = hw::EnumFieldAttr::get(
168 loc, b.getStringAttr(state.getName()), stateType);
169 auto enumConstantOp = hw::EnumConstantOp::create(
170 b, loc, fieldAttr.getType().getValue(), fieldAttr);
171 setEncoding(state, enumConstantOp,
177Value StateEncoding::encode(
StateOp state) {
178 auto it = stateToValue.find(state);
179 assert(it != stateToValue.end() &&
"state not found");
183StateOp StateEncoding::decode(Value value) {
184 auto it = valueToState.find(value);
185 assert(it != valueToState.end() &&
"encoded state not found");
190std::unique_ptr<sv::CasePattern> StateEncoding::getCasePattern(
StateOp state) {
193 cast<hw::EnumConstantOp>(valueToSrcValue[encode(state)].getDefiningOp())
195 return std::make_unique<sv::CaseEnumPattern>(fieldAttr);
198void StateEncoding::setEncoding(
StateOp state, Value v,
bool wire) {
199 assert(stateToValue.find(state) == stateToValue.end() &&
200 "state already encoded");
204 auto loc = machine.getLoc();
205 auto stateType = getStateType();
206 auto stateEncodingWire = sv::RegOp::create(
207 b, loc, stateType, b.getStringAttr(
"to_" + state.getName()),
208 hw::InnerSymAttr::get(state.getNameAttr()));
213 stateToValue[state] = encodedValue;
214 valueToState[encodedValue] = state;
215 valueToSrcValue[encodedValue] = v;
218class MachineOpConverter {
221 MachineOp machineOp, FlatSymbolRefAttr headerName)
222 : machineOp(machineOp), typeScope(typeScope), b(builder),
223 headerName(headerName) {}
243 LogicalResult dispatch();
246 struct StateConversionResult {
250 llvm::SmallVector<Value> outputs;
253 using StateConversionResults = DenseMap<StateOp, StateConversionResult>;
257 FailureOr<StateConversionResult> convertState(
StateOp state);
263 FailureOr<Value> convertTransitions(
StateOp currentState,
264 ArrayRef<TransitionOp> transitions);
270 FailureOr<Operation *>
271 moveOps(Block *block,
272 llvm::function_ref<
bool(Operation *)> exclude =
nullptr);
275 using StateCaseMapping =
277 std::variant<Value, std::shared_ptr<CaseMuxItem>>>;
288 StateCaseMapping assignmentInState;
292 std::optional<Value> defaultValue = {};
298 void buildStateCaseMux(llvm::MutableArrayRef<CaseMuxItem> assignments);
301 std::unique_ptr<StateEncoding> encoding;
304 llvm::SmallVector<StateOp> orderedStates;
315 llvm::DenseMap< VariableOp, Value>>>
316 stateToVariableUpdates;
332 FlatSymbolRefAttr headerName;
335FailureOr<Operation *>
336MachineOpConverter::moveOps(Block *block,
337 llvm::function_ref<
bool(Operation *)> exclude) {
338 for (
auto &op :
llvm::make_early_inc_range(*block)) {
339 if (!isa<comb::CombDialect, hw::HWDialect, fsm::FSMDialect>(
341 return op.emitOpError()
342 <<
"is unsupported (op from the "
343 << op.getDialect()->getNamespace() <<
" dialect).";
345 if (exclude && exclude(&op))
348 if (op.hasTrait<OpTrait::IsTerminator>())
351 op.moveBefore(hwModuleOp.getBodyBlock(), b.getInsertionPoint());
356void MachineOpConverter::buildStateCaseMux(
357 llvm::MutableArrayRef<CaseMuxItem> assignments) {
361 Value select = assignments.front().select;
364 [&](
const CaseMuxItem &item) {
return item.select == select; }) &&
365 "All assignments must use the same select signal.");
369 for (
auto &assignment : assignments) {
370 if (assignment.defaultValue)
371 sv::BPAssignOp::create(b, assignment.wire.getLoc(), assignment.wire,
372 *assignment.defaultValue);
376 caseMux = sv::CaseOp::create(
377 b, machineOp.getLoc(), CaseStmtType::CaseStmt,
379 machineOp.getNumStates() + 1, [&](
size_t caseIdx) {
381 if (caseIdx == machineOp.getNumStates())
382 return std::unique_ptr<sv::CasePattern>(
383 new sv::CaseDefaultPattern(b.getContext()));
384 StateOp state = orderedStates[caseIdx];
385 return encoding->getCasePattern(state);
389 for (
auto assignment : assignments) {
390 OpBuilder::InsertionGuard g(b);
391 for (
auto [caseInfo, stateOp] :
392 llvm::zip(caseMux.getCases(), orderedStates)) {
393 auto assignmentInState = assignment.assignmentInState.find(stateOp);
394 if (assignmentInState == assignment.assignmentInState.end())
396 b.setInsertionPointToEnd(caseInfo.block);
397 if (
auto v = std::get_if<Value>(&assignmentInState->second); v) {
398 sv::BPAssignOp::create(b, machineOp.getLoc(), assignment.wire, *v);
401 llvm::SmallVector<CaseMuxItem, 4> nestedAssignments;
402 nestedAssignments.push_back(
403 *std::get<std::shared_ptr<CaseMuxItem>>(assignmentInState->second));
404 buildStateCaseMux(nestedAssignments);
410LogicalResult MachineOpConverter::dispatch() {
411 b.setInsertionPoint(machineOp);
412 auto loc = machineOp.getLoc();
413 if (machineOp.getNumStates() < 2)
414 return machineOp.emitOpError() <<
"expected at least 2 states.";
421 SmallVector<hw::PortInfo, 16> ports;
424 hw::HWModuleOp::create(b, loc, machineOp.getSymNameAttr(), ports);
425 hwModuleOp->setAttr(emit::getFragmentsAttrName(),
426 b.getArrayAttr({headerName}));
427 b.setInsertionPointToStart(hwModuleOp.getBodyBlock());
431 for (
auto args :
llvm::zip(machineOp.getArguments(),
433 auto machineArg = std::get<0>(args);
434 auto hwModuleArg = std::get<1>(args);
435 machineArg.replaceAllUsesWith(hwModuleArg);
438 auto clock = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.clockIdx);
439 auto reset = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.resetIdx);
443 std::make_unique<StateEncoding>(b, typeScope, machineOp, hwModuleOp);
444 auto stateType = encoding->getStateType();
447 sv::RegOp::create(b, loc, stateType, b.getStringAttr(
"state_next"));
450 b, loc, nextStateWireRead, clock, reset,
451 encoding->encode(machineOp.getInitialStateOp()),
454 llvm::DenseMap<VariableOp, sv::RegOp> variableNextStateWires;
455 for (
auto variableOp : machineOp.front().getOps<
fsm::VariableOp>()) {
456 auto initValueAttr = dyn_cast<IntegerAttr>(variableOp.getInitValueAttr());
458 return variableOp.emitOpError() <<
"expected an integer attribute "
459 "for the initial value.";
460 Type varType = variableOp.getType();
461 auto varLoc = variableOp.getLoc();
462 auto varNextState = sv::RegOp::create(
463 b, varLoc, varType, b.getStringAttr(variableOp.getName() +
"_next"));
467 reset, varResetVal, b.getStringAttr(variableOp.getName() +
"_reg"));
468 variableToRegister[variableOp] = variableReg;
469 variableNextStateWires[variableOp] = varNextState;
477 if (failed(moveOps(&machineOp.front(), [](Operation *op) {
478 return isa<fsm::StateOp, fsm::VariableOp>(op);
483 StateCaseMapping nextStateFromState;
484 StateConversionResults stateConvResults;
485 for (
auto state : machineOp.getBody().getOps<
StateOp>()) {
486 auto stateConvRes = convertState(state);
487 if (failed(stateConvRes))
490 stateConvResults[state] = *stateConvRes;
491 orderedStates.push_back(state);
492 nextStateFromState[state] = {stateConvRes->nextState};
496 llvm::SmallVector<CaseMuxItem, 4> outputCaseAssignments;
497 auto hwPortList = hwModuleOp.getPortList();
498 size_t portIndex = 0;
499 for (
auto &port : hwPortList) {
500 if (!port.isOutput())
502 auto outputPortType = port.type;
503 CaseMuxItem outputAssignment;
504 outputAssignment.wire = sv::RegOp::create(
505 b, machineOp.getLoc(), outputPortType,
506 b.getStringAttr(
"output_" + std::to_string(portIndex)));
507 outputAssignment.select = stateReg;
508 for (
auto &state : orderedStates)
509 outputAssignment.assignmentInState[state] = {
510 stateConvResults[state].outputs[portIndex]};
512 outputCaseAssignments.push_back(outputAssignment);
517 llvm::DenseMap<VariableOp, CaseMuxItem> variableCaseMuxItems;
518 for (
auto &[currentState, it] : stateToVariableUpdates) {
519 for (
auto &[targetState, it2] : it) {
520 for (
auto &[variableOp, targetValue] : it2) {
521 auto caseMuxItemIt = variableCaseMuxItems.find(variableOp);
522 if (caseMuxItemIt == variableCaseMuxItems.end()) {
526 variableCaseMuxItems[variableOp];
527 caseMuxItemIt = variableCaseMuxItems.find(variableOp);
529 assert(variableNextStateWires.count(variableOp));
530 caseMuxItemIt->second.wire = variableNextStateWires[variableOp];
531 caseMuxItemIt->second.select = stateReg;
532 caseMuxItemIt->second.defaultValue =
533 variableToRegister[variableOp].getResult();
536 if (!std::get_if<std::shared_ptr<CaseMuxItem>>(
537 &caseMuxItemIt->second.assignmentInState[currentState])) {
540 CaseMuxItem innerCaseMuxItem;
541 innerCaseMuxItem.wire = caseMuxItemIt->second.wire;
542 innerCaseMuxItem.select = nextStateWireRead;
543 caseMuxItemIt->second.assignmentInState[currentState] = {
544 std::make_shared<CaseMuxItem>(innerCaseMuxItem)};
550 auto &innerCaseMuxItem = std::get<std::shared_ptr<CaseMuxItem>>(
551 caseMuxItemIt->second.assignmentInState[currentState]);
552 innerCaseMuxItem->assignmentInState[targetState] = {targetValue};
558 llvm::SmallVector<CaseMuxItem, 4> nextStateCaseAssignments;
559 nextStateCaseAssignments.push_back(
560 CaseMuxItem{nextStateWire, stateReg, nextStateFromState});
561 for (
auto &[_, caseMuxItem] : variableCaseMuxItems)
562 nextStateCaseAssignments.push_back(caseMuxItem);
563 nextStateCaseAssignments.append(outputCaseAssignments.begin(),
564 outputCaseAssignments.end());
567 auto alwaysCombOp = sv::AlwaysCombOp::create(b, loc);
568 OpBuilder::InsertionGuard g(b);
569 b.setInsertionPointToStart(alwaysCombOp.getBodyBlock());
570 buildStateCaseMux(nextStateCaseAssignments);
574 for (
auto &[variableOp, variableReg] : variableToRegister)
575 variableOp.getResult().replaceAllUsesWith(variableReg);
578 llvm::SmallVector<Value> outputPortAssignments;
579 for (
auto outputAssignment : outputCaseAssignments)
580 outputPortAssignments.push_back(
581 sv::ReadInOutOp::create(b, machineOp.
getLoc(), outputAssignment.wire));
585 auto *oldOutputOp = hwModuleOp.getBodyBlock()->getTerminator();
586 hw::OutputOp::create(b, loc, outputPortAssignments);
587 oldOutputOp->erase();
596MachineOpConverter::convertTransitions(
597 StateOp currentState, ArrayRef<TransitionOp> transitions) {
599 if (transitions.empty()) {
602 nextState = encoding->encode(currentState);
605 auto transition = cast<fsm::TransitionOp>(transitions.front());
606 nextState = encoding->encode(transition.getNextStateOp());
609 if (transition.hasAction()) {
612 auto actionMoveOpsRes =
613 moveOps(&transition.getAction().front(),
614 [](Operation *op) { return isa<fsm::UpdateOp>(op); });
615 if (failed(actionMoveOpsRes))
619 DenseMap<fsm::VariableOp, Value> variableUpdates;
620 for (
auto updateOp : transition.getAction().getOps<
fsm::UpdateOp>()) {
621 VariableOp variableOp = updateOp.getVariableOp();
622 variableUpdates[variableOp] = updateOp.getValue();
625 stateToVariableUpdates[currentState][transition.getNextStateOp()] =
630 if (transition.hasGuard()) {
633 auto guardOpRes = moveOps(&transition.getGuard().front());
634 if (failed(guardOpRes))
637 auto guardOp = cast<ReturnOp>(*guardOpRes);
638 assert(guardOp &&
"guard should be defined");
639 auto guard = guardOp.getOperand();
640 auto otherNextState =
641 convertTransitions(currentState, transitions.drop_front());
642 if (failed(otherNextState))
645 b, transition.getLoc(), guard, nextState, *otherNextState,
false);
646 nextState = nextStateMux;
650 assert(nextState &&
"next state should be defined");
654FailureOr<MachineOpConverter::StateConversionResult>
655MachineOpConverter::convertState(
StateOp state) {
656 MachineOpConverter::StateConversionResult res;
660 if (!state.getOutput().empty()) {
661 auto outputOpRes = moveOps(&state.getOutput().front());
662 if (failed(outputOpRes))
665 OutputOp outputOp = cast<fsm::OutputOp>(*outputOpRes);
666 res.outputs = outputOp.getOperands();
669 SmallVector<TransitionOp> transitions;
670 for (
auto &op : state.getTransitions().getOps()) {
671 if (
auto transOp = dyn_cast<TransitionOp>(op)) {
672 transitions.push_back(transOp);
676 auto opClone = b.clone(op);
677 for (
auto [i, res] :
llvm::enumerate(op.getResults()))
678 res.replaceAllUsesWith(opClone->getResult(i));
683 auto nextStateRes = convertTransitions(state, transitions);
684 if (failed(nextStateRes))
686 res.nextState = *nextStateRes;
690struct FSMToSVPass :
public circt::impl::ConvertFSMToSVBase<FSMToSVPass> {
691 void runOnOperation()
override;
694void FSMToSVPass::runOnOperation() {
695 auto module = getOperation();
696 auto loc =
module.getLoc();
697 auto b = OpBuilder(module);
700 auto machineOps = llvm::to_vector(module.getOps<
MachineOp>());
701 if (machineOps.empty()) {
702 markAllAnalysesPreserved();
709 b.setInsertionPointToStart(module.getBody());
712 typeScope.getBodyRegion().push_back(
new Block());
714 auto file = emit::FileOp::create(b, loc,
"fsm_enum_typedefs.sv", [&] {
715 emit::RefOp::create(b, loc,
716 FlatSymbolRefAttr::get(typeScope.getSymNameAttr()));
718 auto fragment = emit::FragmentOp::create(b, loc,
"FSM_ENUM_TYPEDEFS", [&] {
719 sv::VerbatimOp::create(b, loc,
"`include \"" + file.getFileName() +
"\"");
722 auto headerName = FlatSymbolRefAttr::get(fragment.getSymNameAttr());
725 for (
auto machineOp : machineOps) {
726 MachineOpConverter converter(b, typeScope, machineOp, headerName);
728 if (failed(converter.dispatch())) {
735 llvm::SmallVector<HWInstanceOp> instances;
736 module.walk([&](HWInstanceOp instance) { instances.push_back(instance); });
737 for (
auto instance : instances) {
739 module.lookupSymbol<hw::HWModuleOp>(instance.getMachine());
741 "FSM machine should have been converted to a hw.module");
743 b.setInsertionPoint(instance);
744 llvm::SmallVector<Value, 4> operands;
745 llvm::transform(instance.getOperands(), std::back_inserter(operands),
746 [&](
auto operand) { return operand; });
747 auto hwInstance = hw::InstanceOp::create(
748 b, instance.getLoc(), fsmHWModule, b.getStringAttr(instance.getName()),
750 instance.replaceAllUsesWith(hwInstance);
754 assert(!typeScope.getBodyBlock()->empty() &&
"missing type decls");
760 return std::make_unique<FSMToSVPass>();
assert(baseType &&"element must be base type")
static ClkRstIdxs getMachinePortInfo(SmallVectorImpl< hw::PortInfo > &ports, MachineOp machine, OpBuilder &b)
static void cloneConstantsIntoRegion(Region ®ion, OpBuilder &builder)
static Location getLoc(DefSlot slot)
static Block * getBodyBlock(FModuleLike mod)
create(str sym_name, Type type, str verilog_name=None)
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createConvertFSMToSVPass()
This holds the name, type, direction of a module's ports.
size_t argNum
This is the argument index or the result index depending on the direction.