17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/RegionUtils.h"
19#include "llvm/ADT/TypeSwitch.h"
25#define GEN_PASS_DEF_CONVERTFSMTOSV
26#include "circt/Conversion/Passes.h.inc"
43 machine.getHWPortInfo(ports);
44 ClkRstIdxs specialPorts;
48 clock.
name = b.getStringAttr(
"clk");
49 clock.
dir = hw::ModulePort::Direction::Input;
50 clock.
type = seq::ClockType::get(b.getContext());
51 clock.
argNum = machine.getNumArguments();
52 ports.push_back(clock);
53 specialPorts.clockIdx = clock.
argNum;
57 reset.
name = b.getStringAttr(
"rst");
58 reset.
dir = hw::ModulePort::Direction::Input;
59 reset.
type = b.getI1Type();
60 reset.
argNum = machine.getNumArguments() + 1;
61 ports.push_back(reset);
62 specialPorts.resetIdx = reset.
argNum;
70 llvm::SetVector<Value> captures;
71 getUsedValuesDefinedAbove(region, region, captures);
73 OpBuilder::InsertionGuard guard(builder);
74 builder.setInsertionPointToStart(&region.front());
77 for (
auto &capture : captures) {
78 Operation *op = capture.getDefiningOp();
79 if (!op || !op->hasTrait<OpTrait::ConstantLike>())
82 Operation *cloned = builder.clone(*op);
83 for (
auto [orig, replacement] :
84 llvm::zip(op->getResults(), cloned->getResults()))
85 replaceAllUsesInRegionWith(orig, replacement, region);
// Accessor for the `stateType` member. Elsewhere in this listing the
// constructor assigns `stateType` from hw::TypeAliasType::get(...).
106 Type getStateType() {
return stateType; }
// Builds the case pattern that matches `state`. The definition later in this
// listing returns an sv::CaseEnumPattern wrapped in a std::unique_ptr, built
// from the enum constant op recorded as the source value of the state's
// encoding.
109 std::unique_ptr<sv::CasePattern> getCasePattern(
StateOp state);
// Records `v` as the encoding of `state`. The definition asserts that `state`
// has not been encoded before and updates the stateToValue, valueToState and
// valueToSrcValue maps.
// NOTE(review): the definition also creates an sv.reg named "to_<state>";
// presumably this only happens when `wire` is true, but that branch is not
// visible in this listing -- confirm against the full source.
118 void setEncoding(
StateOp state, Value v,
bool wire =
false);
142 : typeScope(typeScope), b(b), machine(machine), hwModule(hwModule) {
143 Location loc = machine.getLoc();
144 llvm::SmallVector<Attribute> stateNames;
146 for (
auto state : machine.getBody().getOps<
StateOp>())
147 stateNames.push_back(b.getStringAttr(state.
getName()));
151 hw::EnumType::get(b.getContext(), b.getArrayAttr(stateNames));
153 OpBuilder::InsertionGuard guard(b);
154 b.setInsertionPointToStart(&typeScope.getBodyRegion().front());
156 loc, b.getStringAttr(hwModule.getName() +
"_state_t"),
157 TypeAttr::get(rawEnumType),
nullptr);
159 stateType = hw::TypeAliasType::get(
160 SymbolRefAttr::get(typeScope.getSymNameAttr(),
161 {FlatSymbolRefAttr::get(typedeclEnumType)}),
165 b.setInsertionPointToStart(&hwModule.getBody().front());
166 for (
auto state : machine.getBody().getOps<
StateOp>()) {
167 auto fieldAttr = hw::EnumFieldAttr::get(
168 loc, b.getStringAttr(state.getName()), stateType);
169 auto enumConstantOp = b.create<hw::EnumConstantOp>(
170 loc, fieldAttr.getType().getValue(), fieldAttr);
171 setEncoding(state, enumConstantOp,
177Value StateEncoding::encode(
StateOp state) {
178 auto it = stateToValue.find(state);
179 assert(it != stateToValue.end() &&
"state not found");
183StateOp StateEncoding::decode(Value value) {
184 auto it = valueToState.find(value);
185 assert(it != valueToState.end() &&
"encoded state not found");
190std::unique_ptr<sv::CasePattern> StateEncoding::getCasePattern(
StateOp state) {
193 cast<hw::EnumConstantOp>(valueToSrcValue[encode(state)].getDefiningOp())
195 return std::make_unique<sv::CaseEnumPattern>(fieldAttr);
198void StateEncoding::setEncoding(
StateOp state, Value v,
bool wire) {
199 assert(stateToValue.find(state) == stateToValue.end() &&
200 "state already encoded");
204 auto loc = machine.getLoc();
205 auto stateType = getStateType();
206 auto stateEncodingWire = b.create<
sv::RegOp>(
207 loc, stateType, b.getStringAttr(
"to_" + state.getName()),
208 hw::InnerSymAttr::get(state.getNameAttr()));
213 stateToValue[state] = encodedValue;
214 valueToState[encodedValue] = state;
215 valueToSrcValue[encodedValue] = v;
218class MachineOpConverter {
221 MachineOp machineOp, FlatSymbolRefAttr headerName)
222 : machineOp(machineOp), typeScope(typeScope), b(builder),
223 headerName(headerName) {}
// Converts the machine op held by this converter into a hw.module.
// The definition emits an op error when the machine has fewer than 2 states
// or when a variable's initial value is not an integer attribute.
243 LogicalResult dispatch();
246 struct StateConversionResult {
250 llvm::SmallVector<Value> outputs;
253 using StateConversionResults = DenseMap<StateOp, StateConversionResult>;
// Converts a single `state`. When the state's output region is non-empty its
// ops are moved via moveOps() and the operands of the fsm.output terminator
// become the result's `outputs`; the state's transitions are then resolved
// via convertTransitions() into the result's `nextState`. Failure from either
// step is propagated.
257 FailureOr<StateConversionResult> convertState(
StateOp state);
// Computes the next-state value for `currentState` from its `transitions`.
// With no transitions, the encoding of `currentState` itself is used.
// Otherwise the first transition selects the encoding of its target state;
// if that transition has a guard, the remaining transitions are converted
// recursively and combined with the guard value.
// NOTE(review): the combining op looks like a mux on the guard, but its
// creation line is not visible in this listing -- confirm.
// Variable updates found in a transition's action region are recorded in
// stateToVariableUpdates.
263 FailureOr<Value> convertTransitions(
StateOp currentState,
264 ArrayRef<TransitionOp> transitions);
// Moves operations of `block` into the hw.module body, before the builder's
// current insertion point. An op that does not belong to the comb, hw or fsm
// dialect produces an "is unsupported" op error.
// `exclude`, when non-null, is consulted for every op.
// NOTE(review): ops for which `exclude` returns true are presumably left in
// place, and terminators presumably are not moved -- the corresponding
// statements are not visible in this listing; confirm.
// Callers in this listing cast the returned operation to the block's
// terminator type (ReturnOp for guards, fsm::OutputOp for output regions).
270 FailureOr<Operation *>
271 moveOps(Block *block,
272 llvm::function_ref<
bool(Operation *)> exclude =
nullptr);
275 using StateCaseMapping =
277 std::variant<Value, std::shared_ptr<CaseMuxItem>>>;
288 StateCaseMapping assignmentInState;
292 std::optional<Value> defaultValue = {};
// Emits an sv.case for `assignments` at the builder's insertion point.
// All items must share the same `select` value (asserted in the definition).
// Items carrying a `defaultValue` first get a blocking assignment of that
// default to their wire. Then, per state case, each item that has an entry
// for the state either gets a blocking assignment of the mapped Value to its
// wire or, when the entry is a nested CaseMuxItem, a recursively built case.
298 void buildStateCaseMux(llvm::MutableArrayRef<CaseMuxItem> assignments);
301 std::unique_ptr<StateEncoding> encoding;
304 llvm::SmallVector<StateOp> orderedStates;
315 llvm::DenseMap< VariableOp, Value>>>
316 stateToVariableUpdates;
332 FlatSymbolRefAttr headerName;
335FailureOr<Operation *>
336MachineOpConverter::moveOps(Block *block,
337 llvm::function_ref<
bool(Operation *)> exclude) {
338 for (
auto &op :
llvm::make_early_inc_range(*block)) {
339 if (!isa<comb::CombDialect, hw::HWDialect, fsm::FSMDialect>(
341 return op.emitOpError()
342 <<
"is unsupported (op from the "
343 << op.getDialect()->getNamespace() <<
" dialect).";
345 if (exclude && exclude(&op))
348 if (op.hasTrait<OpTrait::IsTerminator>())
351 op.moveBefore(hwModuleOp.getBodyBlock(), b.getInsertionPoint());
356void MachineOpConverter::buildStateCaseMux(
357 llvm::MutableArrayRef<CaseMuxItem> assignments) {
361 Value select = assignments.front().select;
364 [&](
const CaseMuxItem &item) {
return item.select == select; }) &&
365 "All assignments must use the same select signal.");
369 for (
auto &assignment : assignments) {
370 if (assignment.defaultValue)
371 b.create<sv::BPAssignOp>(assignment.wire.getLoc(), assignment.wire,
372 *assignment.defaultValue);
376 caseMux = b.create<sv::CaseOp>(
377 machineOp.getLoc(), CaseStmtType::CaseStmt,
379 machineOp.getNumStates() + 1, [&](
size_t caseIdx) {
381 if (caseIdx == machineOp.getNumStates())
382 return std::unique_ptr<sv::CasePattern>(
384 StateOp state = orderedStates[caseIdx];
385 return encoding->getCasePattern(state);
389 for (
auto assignment : assignments) {
390 OpBuilder::InsertionGuard g(b);
391 for (
auto [caseInfo, stateOp] :
392 llvm::zip(caseMux.getCases(), orderedStates)) {
393 auto assignmentInState = assignment.assignmentInState.find(stateOp);
394 if (assignmentInState == assignment.assignmentInState.end())
396 b.setInsertionPointToEnd(caseInfo.block);
397 if (
auto v = std::get_if<Value>(&assignmentInState->second); v) {
398 b.create<sv::BPAssignOp>(machineOp.getLoc(), assignment.wire, *v);
401 llvm::SmallVector<CaseMuxItem, 4> nestedAssignments;
402 nestedAssignments.push_back(
403 *std::get<std::shared_ptr<CaseMuxItem>>(assignmentInState->second));
404 buildStateCaseMux(nestedAssignments);
410LogicalResult MachineOpConverter::dispatch() {
411 b.setInsertionPoint(machineOp);
412 auto loc = machineOp.getLoc();
413 if (machineOp.getNumStates() < 2)
414 return machineOp.emitOpError() <<
"expected at least 2 states.";
421 SmallVector<hw::PortInfo, 16> ports;
423 hwModuleOp = b.create<
hw::HWModuleOp>(loc, machineOp.getSymNameAttr(), ports);
424 hwModuleOp->setAttr(emit::getFragmentsAttrName(),
425 b.getArrayAttr({headerName}));
426 b.setInsertionPointToStart(hwModuleOp.getBodyBlock());
430 for (
auto args :
llvm::zip(machineOp.getArguments(),
432 auto machineArg = std::get<0>(args);
433 auto hwModuleArg = std::get<1>(args);
434 machineArg.replaceAllUsesWith(hwModuleArg);
437 auto clock = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.clockIdx);
438 auto reset = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.resetIdx);
442 std::make_unique<StateEncoding>(b, typeScope, machineOp, hwModuleOp);
443 auto stateType = encoding->getStateType();
447 auto nextStateWireRead = b.create<
sv::ReadInOutOp>(loc, nextStateWire);
449 loc, nextStateWireRead, clock, reset,
450 encoding->encode(machineOp.getInitialStateOp()),
453 llvm::DenseMap<VariableOp, sv::RegOp> variableNextStateWires;
454 for (
auto variableOp : machineOp.front().getOps<
fsm::VariableOp>()) {
455 auto initValueAttr = dyn_cast<IntegerAttr>(variableOp.getInitValueAttr());
457 return variableOp.emitOpError() <<
"expected an integer attribute "
458 "for the initial value.";
459 Type varType = variableOp.getType();
460 auto varLoc = variableOp.getLoc();
462 varLoc, varType, b.getStringAttr(variableOp.getName() +
"_next"));
463 auto varResetVal = b.create<
hw::ConstantOp>(varLoc, initValueAttr);
466 varResetVal, b.getStringAttr(variableOp.getName() +
"_reg"));
467 variableToRegister[variableOp] = variableReg;
468 variableNextStateWires[variableOp] = varNextState;
476 if (failed(moveOps(&machineOp.front(), [](Operation *op) {
477 return isa<fsm::StateOp, fsm::VariableOp>(op);
482 StateCaseMapping nextStateFromState;
483 StateConversionResults stateConvResults;
484 for (
auto state : machineOp.getBody().getOps<
StateOp>()) {
485 auto stateConvRes = convertState(state);
486 if (failed(stateConvRes))
489 stateConvResults[state] = *stateConvRes;
490 orderedStates.push_back(state);
491 nextStateFromState[state] = {stateConvRes->nextState};
495 llvm::SmallVector<CaseMuxItem, 4> outputCaseAssignments;
496 auto hwPortList = hwModuleOp.getPortList();
497 size_t portIndex = 0;
498 for (
auto &port : hwPortList) {
499 if (!port.isOutput())
501 auto outputPortType = port.type;
502 CaseMuxItem outputAssignment;
503 outputAssignment.wire = b.create<
sv::RegOp>(
504 machineOp.getLoc(), outputPortType,
505 b.getStringAttr(
"output_" + std::to_string(portIndex)));
506 outputAssignment.select = stateReg;
507 for (
auto &state : orderedStates)
508 outputAssignment.assignmentInState[state] = {
509 stateConvResults[state].outputs[portIndex]};
511 outputCaseAssignments.push_back(outputAssignment);
516 llvm::DenseMap<VariableOp, CaseMuxItem> variableCaseMuxItems;
517 for (
auto &[currentState, it] : stateToVariableUpdates) {
518 for (
auto &[targetState, it2] : it) {
519 for (
auto &[variableOp, targetValue] : it2) {
520 auto caseMuxItemIt = variableCaseMuxItems.find(variableOp);
521 if (caseMuxItemIt == variableCaseMuxItems.end()) {
525 variableCaseMuxItems[variableOp];
526 caseMuxItemIt = variableCaseMuxItems.find(variableOp);
528 assert(variableNextStateWires.count(variableOp));
529 caseMuxItemIt->second.wire = variableNextStateWires[variableOp];
530 caseMuxItemIt->second.select = stateReg;
531 caseMuxItemIt->second.defaultValue =
532 variableToRegister[variableOp].getResult();
535 if (!std::get_if<std::shared_ptr<CaseMuxItem>>(
536 &caseMuxItemIt->second.assignmentInState[currentState])) {
539 CaseMuxItem innerCaseMuxItem;
540 innerCaseMuxItem.wire = caseMuxItemIt->second.wire;
541 innerCaseMuxItem.select = nextStateWireRead;
542 caseMuxItemIt->second.assignmentInState[currentState] = {
543 std::make_shared<CaseMuxItem>(innerCaseMuxItem)};
549 auto &innerCaseMuxItem = std::get<std::shared_ptr<CaseMuxItem>>(
550 caseMuxItemIt->second.assignmentInState[currentState]);
551 innerCaseMuxItem->assignmentInState[targetState] = {targetValue};
557 llvm::SmallVector<CaseMuxItem, 4> nextStateCaseAssignments;
558 nextStateCaseAssignments.push_back(
559 CaseMuxItem{nextStateWire, stateReg, nextStateFromState});
560 for (
auto &[_, caseMuxItem] : variableCaseMuxItems)
561 nextStateCaseAssignments.push_back(caseMuxItem);
562 nextStateCaseAssignments.append(outputCaseAssignments.begin(),
563 outputCaseAssignments.end());
566 auto alwaysCombOp = b.create<sv::AlwaysCombOp>(loc);
567 OpBuilder::InsertionGuard g(b);
568 b.setInsertionPointToStart(alwaysCombOp.getBodyBlock());
569 buildStateCaseMux(nextStateCaseAssignments);
573 for (
auto &[variableOp, variableReg] : variableToRegister)
574 variableOp.getResult().replaceAllUsesWith(variableReg);
577 llvm::SmallVector<Value> outputPortAssignments;
578 for (
auto outputAssignment : outputCaseAssignments)
579 outputPortAssignments.push_back(
580 b.create<
sv::ReadInOutOp>(machineOp.getLoc(), outputAssignment.wire));
584 auto *oldOutputOp = hwModuleOp.getBodyBlock()->getTerminator();
585 b.create<hw::OutputOp>(loc, outputPortAssignments);
586 oldOutputOp->erase();
595MachineOpConverter::convertTransitions(
596 StateOp currentState, ArrayRef<TransitionOp> transitions) {
598 if (transitions.empty()) {
601 nextState = encoding->encode(currentState);
604 auto transition = cast<fsm::TransitionOp>(transitions.front());
605 nextState = encoding->encode(transition.getNextStateOp());
608 if (transition.hasAction()) {
611 auto actionMoveOpsRes =
612 moveOps(&transition.getAction().front(),
613 [](Operation *op) { return isa<fsm::UpdateOp>(op); });
614 if (failed(actionMoveOpsRes))
618 DenseMap<fsm::VariableOp, Value> variableUpdates;
619 for (
auto updateOp : transition.getAction().getOps<
fsm::UpdateOp>()) {
620 VariableOp variableOp = updateOp.getVariableOp();
621 variableUpdates[variableOp] = updateOp.getValue();
624 stateToVariableUpdates[currentState][transition.getNextStateOp()] =
629 if (transition.hasGuard()) {
632 auto guardOpRes = moveOps(&transition.getGuard().front());
633 if (failed(guardOpRes))
636 auto guardOp = cast<ReturnOp>(*guardOpRes);
637 assert(guardOp &&
"guard should be defined");
638 auto guard = guardOp.getOperand();
639 auto otherNextState =
640 convertTransitions(currentState, transitions.drop_front());
641 if (failed(otherNextState))
644 transition.getLoc(), guard, nextState, *otherNextState,
false);
645 nextState = nextStateMux;
649 assert(nextState &&
"next state should be defined");
653FailureOr<MachineOpConverter::StateConversionResult>
654MachineOpConverter::convertState(
StateOp state) {
655 MachineOpConverter::StateConversionResult res;
659 if (!state.getOutput().empty()) {
660 auto outputOpRes = moveOps(&state.getOutput().front());
661 if (failed(outputOpRes))
664 OutputOp outputOp = cast<fsm::OutputOp>(*outputOpRes);
665 res.outputs = outputOp.getOperands();
668 auto transitions = llvm::SmallVector<TransitionOp>(
672 auto nextStateRes = convertTransitions(state, transitions);
673 if (failed(nextStateRes))
675 res.nextState = *nextStateRes;
679struct FSMToSVPass :
public circt::impl::ConvertFSMToSVBase<FSMToSVPass> {
680 void runOnOperation()
override;
683void FSMToSVPass::runOnOperation() {
684 auto module = getOperation();
685 auto loc =
module.getLoc();
686 auto b = OpBuilder(module);
689 auto machineOps = llvm::to_vector(module.getOps<
MachineOp>());
690 if (machineOps.empty()) {
691 markAllAnalysesPreserved();
698 b.setInsertionPointToStart(module.getBody());
701 typeScope.getBodyRegion().push_back(
new Block());
703 auto file = b.
create<emit::FileOp>(loc,
"fsm_enum_typedefs.sv", [&] {
704 b.
create<emit::RefOp>(loc,
705 FlatSymbolRefAttr::get(typeScope.getSymNameAttr()));
707 auto fragment = b.create<emit::FragmentOp>(loc,
"FSM_ENUM_TYPEDEFS", [&] {
708 b.create<sv::VerbatimOp>(loc,
"`include \"" + file.getFileName() +
"\"");
711 auto headerName = FlatSymbolRefAttr::get(fragment.getSymNameAttr());
714 for (
auto machineOp : machineOps) {
715 MachineOpConverter converter(b, typeScope, machineOp, headerName);
717 if (failed(converter.dispatch())) {
724 llvm::SmallVector<HWInstanceOp> instances;
725 module.walk([&](HWInstanceOp instance) { instances.push_back(instance); });
726 for (
auto instance : instances) {
728 module.lookupSymbol<hw::HWModuleOp>(instance.getMachine());
730 "FSM machine should have been converted to a hw.module");
732 b.setInsertionPoint(instance);
733 llvm::SmallVector<Value, 4> operands;
734 llvm::transform(instance.getOperands(), std::back_inserter(operands),
735 [&](
auto operand) { return operand; });
736 auto hwInstance = b.create<hw::InstanceOp>(
737 instance.getLoc(), fsmHWModule, b.getStringAttr(instance.getName()),
739 instance.replaceAllUsesWith(hwInstance);
743 assert(!typeScope.getBodyBlock()->empty() &&
"missing type decls");
749 return std::make_unique<FSMToSVPass>();
assert(baseType &&"element must be base type")
static ClkRstIdxs getMachinePortInfo(SmallVectorImpl< hw::PortInfo > &ports, MachineOp machine, OpBuilder &b)
static void cloneConstantsIntoRegion(Region &region, OpBuilder &builder)
static Block * getBodyBlock(FModuleLike mod)
create(str sym_name, Type type, str verilog_name=None)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createConvertFSMToSVPass()
This holds the name, type, direction of a module's ports.
size_t argNum
This is the argument index or the result index depending on the direction.