17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/RegionUtils.h"
19 #include "llvm/ADT/TypeSwitch.h"
25 #define GEN_PASS_DEF_CONVERTFSMTOSV
26 #include "circt/Conversion/Passes.h.inc"
30 using namespace circt;
43 machine.getHWPortInfo(ports);
44 ClkRstIdxs specialPorts;
48 clock.name = b.getStringAttr(
"clk");
51 clock.argNum = machine.getNumArguments();
52 ports.push_back(clock);
53 specialPorts.clockIdx = clock.argNum;
57 reset.name = b.getStringAttr(
"rst");
59 reset.type = b.getI1Type();
60 reset.argNum = machine.getNumArguments() + 1;
61 ports.push_back(reset);
62 specialPorts.resetIdx = reset.argNum;
70 llvm::SetVector<Value> captures;
71 getUsedValuesDefinedAbove(region, region, captures);
73 OpBuilder::InsertionGuard guard(builder);
74 builder.setInsertionPointToStart(&region.front());
77 for (
auto &capture : captures) {
78 Operation *op = capture.getDefiningOp();
79 if (!op || !op->hasTrait<OpTrait::ConstantLike>())
82 Operation *cloned = builder.clone(*op);
83 for (
auto [orig, replacement] :
84 llvm::zip(op->getResults(), cloned->getResults()))
85 replaceAllUsesInRegionWith(orig, replacement, region);
106 Type getStateType() {
return stateType; }
109 std::unique_ptr<sv::CasePattern> getCasePattern(
StateOp state);
118 void setEncoding(
StateOp state, Value v,
bool wire =
false);
142 : typeScope(typeScope), b(b), machine(machine), hwModule(hwModule) {
143 Location loc = machine.getLoc();
144 llvm::SmallVector<Attribute> stateNames;
146 for (
auto state : machine.getBody().getOps<
StateOp>())
147 stateNames.push_back(b.getStringAttr(state.getName()));
153 OpBuilder::InsertionGuard guard(b);
154 b.setInsertionPointToStart(&typeScope.getBodyRegion().front());
156 loc, b.getStringAttr(hwModule.getName() +
"_state_t"),
161 {FlatSymbolRefAttr::get(typedeclEnumType)}),
165 b.setInsertionPointToStart(&hwModule.getBody().front());
166 for (
auto state : machine.getBody().getOps<
StateOp>()) {
168 loc, b.getStringAttr(state.getName()), stateType);
169 auto enumConstantOp = b.create<hw::EnumConstantOp>(
170 loc, fieldAttr.getType().getValue(), fieldAttr);
171 setEncoding(state, enumConstantOp,
177 Value StateEncoding::encode(
StateOp state) {
178 auto it = stateToValue.find(state);
179 assert(it != stateToValue.end() &&
"state not found");
183 StateOp StateEncoding::decode(Value value) {
184 auto it = valueToState.find(value);
185 assert(it != valueToState.end() &&
"encoded state not found");
190 std::unique_ptr<sv::CasePattern> StateEncoding::getCasePattern(
StateOp state) {
193 cast<hw::EnumConstantOp>(valueToSrcValue[encode(state)].getDefiningOp())
195 return std::make_unique<sv::CaseEnumPattern>(fieldAttr);
198 void StateEncoding::setEncoding(
StateOp state, Value v,
bool wire) {
199 assert(stateToValue.find(state) == stateToValue.end() &&
200 "state already encoded");
204 auto loc = machine.getLoc();
205 auto stateType = getStateType();
206 auto stateEncodingWire = b.create<
sv::RegOp>(
207 loc, stateType, b.getStringAttr(
"to_" + state.getName()),
213 stateToValue[state] = encodedValue;
214 valueToState[encodedValue] = state;
215 valueToSrcValue[encodedValue] = v;
218 class MachineOpConverter {
221 MachineOp machineOp, FlatSymbolRefAttr headerName)
222 : machineOp(machineOp), typeScope(typeScope), b(builder),
223 headerName(headerName) {}
243 LogicalResult dispatch();
246 struct StateConversionResult {
250 llvm::SmallVector<Value> outputs;
253 using StateConversionResults = DenseMap<StateOp, StateConversionResult>;
257 FailureOr<StateConversionResult> convertState(
StateOp state);
263 FailureOr<Value> convertTransitions(
StateOp currentState,
264 ArrayRef<TransitionOp> transitions);
270 FailureOr<Operation *>
271 moveOps(Block *block,
272 llvm::function_ref<
bool(Operation *)> exclude =
nullptr);
275 using StateCaseMapping =
277 std::variant<Value, std::shared_ptr<CaseMuxItem>>>;
288 StateCaseMapping assignmentInState;
292 std::optional<Value> defaultValue = {};
298 void buildStateCaseMux(llvm::MutableArrayRef<CaseMuxItem> assignments);
301 std::unique_ptr<StateEncoding> encoding;
304 llvm::SmallVector<StateOp> orderedStates;
315 llvm::DenseMap< VariableOp, Value>>>
316 stateToVariableUpdates;
332 FlatSymbolRefAttr headerName;
335 FailureOr<Operation *>
336 MachineOpConverter::moveOps(Block *block,
337 llvm::function_ref<
bool(Operation *)> exclude) {
338 for (
auto &op : llvm::make_early_inc_range(*block)) {
339 if (!isa<comb::CombDialect, hw::HWDialect, fsm::FSMDialect>(
341 return op.emitOpError()
342 <<
"is unsupported (op from the "
343 << op.getDialect()->getNamespace() <<
" dialect).";
345 if (exclude && exclude(&op))
348 if (op.hasTrait<OpTrait::IsTerminator>())
351 op.moveBefore(hwModuleOp.getBodyBlock(), b.getInsertionPoint());
356 void MachineOpConverter::buildStateCaseMux(
357 llvm::MutableArrayRef<CaseMuxItem> assignments) {
361 Value select = assignments.front().select;
364 [&](
const CaseMuxItem &item) {
return item.select == select; }) &&
365 "All assignments must use the same select signal.");
369 for (
auto &assignment : assignments) {
370 if (assignment.defaultValue)
371 b.create<sv::BPAssignOp>(assignment.wire.getLoc(), assignment.wire,
372 *assignment.defaultValue);
376 caseMux = b.create<sv::CaseOp>(
377 machineOp.getLoc(), CaseStmtType::CaseStmt,
379 machineOp.getNumStates() + 1, [&](
size_t caseIdx) {
381 if (caseIdx == machineOp.getNumStates())
382 return std::unique_ptr<sv::CasePattern>(
383 new sv::CaseDefaultPattern(b.getContext()));
384 StateOp state = orderedStates[caseIdx];
385 return encoding->getCasePattern(state);
389 for (
auto assignment : assignments) {
390 OpBuilder::InsertionGuard g(b);
391 for (
auto [caseInfo, stateOp] :
392 llvm::zip(caseMux.getCases(), orderedStates)) {
393 auto assignmentInState = assignment.assignmentInState.find(stateOp);
394 if (assignmentInState == assignment.assignmentInState.end())
396 b.setInsertionPointToEnd(caseInfo.block);
397 if (
auto v = std::get_if<Value>(&assignmentInState->second); v) {
398 b.create<sv::BPAssignOp>(machineOp.getLoc(), assignment.wire, *v);
401 llvm::SmallVector<CaseMuxItem, 4> nestedAssignments;
402 nestedAssignments.push_back(
403 *
std::get<std::shared_ptr<CaseMuxItem>>(assignmentInState->second));
404 buildStateCaseMux(nestedAssignments);
410 LogicalResult MachineOpConverter::dispatch() {
411 b.setInsertionPoint(machineOp);
412 auto loc = machineOp.getLoc();
413 if (machineOp.getNumStates() < 2)
414 return machineOp.emitOpError() <<
"expected at least 2 states.";
421 SmallVector<hw::PortInfo, 16> ports;
423 hwModuleOp = b.create<
hw::HWModuleOp>(loc, machineOp.getSymNameAttr(), ports);
425 b.getArrayAttr({headerName}));
426 b.setInsertionPointToStart(hwModuleOp.getBodyBlock());
430 for (
auto args : llvm::zip(machineOp.getArguments(),
431 hwModuleOp.getBodyBlock()->getArguments())) {
432 auto machineArg = std::get<0>(args);
433 auto hwModuleArg = std::get<1>(args);
434 machineArg.replaceAllUsesWith(hwModuleArg);
437 auto clock = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.clockIdx);
438 auto reset = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.resetIdx);
442 std::make_unique<StateEncoding>(b, typeScope, machineOp, hwModuleOp);
443 auto stateType = encoding->getStateType();
447 auto nextStateWireRead = b.create<
sv::ReadInOutOp>(loc, nextStateWire);
449 loc, nextStateWireRead, clock, reset,
450 encoding->encode(machineOp.getInitialStateOp()),
453 llvm::DenseMap<VariableOp, sv::RegOp> variableNextStateWires;
454 for (
auto variableOp : machineOp.front().getOps<fsm::VariableOp>()) {
455 auto initValueAttr = dyn_cast<IntegerAttr>(variableOp.getInitValueAttr());
457 return variableOp.emitOpError() <<
"expected an integer attribute "
458 "for the initial value.";
459 Type varType = variableOp.getType();
460 auto varLoc = variableOp.getLoc();
462 varLoc, varType, b.getStringAttr(variableOp.getName() +
"_next"));
463 auto varResetVal = b.create<
hw::ConstantOp>(varLoc, initValueAttr);
466 varResetVal, b.getStringAttr(variableOp.getName() +
"_reg"));
467 variableToRegister[variableOp] = variableReg;
468 variableNextStateWires[variableOp] = varNextState;
476 if (failed(moveOps(&machineOp.front(), [](Operation *op) {
477 return isa<fsm::StateOp, fsm::VariableOp>(op);
482 StateCaseMapping nextStateFromState;
483 StateConversionResults stateConvResults;
484 for (
auto state : machineOp.getBody().getOps<
StateOp>()) {
485 auto stateConvRes = convertState(state);
486 if (failed(stateConvRes))
489 stateConvResults[state] = *stateConvRes;
490 orderedStates.push_back(state);
491 nextStateFromState[state] = {stateConvRes->nextState};
495 llvm::SmallVector<CaseMuxItem, 4> outputCaseAssignments;
496 auto hwPortList = hwModuleOp.getPortList();
497 size_t portIndex = 0;
498 for (
auto &port : hwPortList) {
499 if (!port.isOutput())
501 auto outputPortType = port.type;
502 CaseMuxItem outputAssignment;
503 outputAssignment.wire = b.create<
sv::RegOp>(
504 machineOp.getLoc(), outputPortType,
505 b.getStringAttr(
"output_" + std::to_string(portIndex)));
506 outputAssignment.select = stateReg;
507 for (
auto &state : orderedStates)
508 outputAssignment.assignmentInState[state] = {
509 stateConvResults[state].outputs[portIndex]};
511 outputCaseAssignments.push_back(outputAssignment);
516 llvm::DenseMap<VariableOp, CaseMuxItem> variableCaseMuxItems;
517 for (
auto &[currentState, it] : stateToVariableUpdates) {
518 for (
auto &[targetState, it2] : it) {
519 for (
auto &[variableOp, targetValue] : it2) {
520 auto caseMuxItemIt = variableCaseMuxItems.find(variableOp);
521 if (caseMuxItemIt == variableCaseMuxItems.end()) {
525 variableCaseMuxItems[variableOp];
526 caseMuxItemIt = variableCaseMuxItems.find(variableOp);
528 assert(variableNextStateWires.count(variableOp));
529 caseMuxItemIt->second.wire = variableNextStateWires[variableOp];
530 caseMuxItemIt->second.select = stateReg;
531 caseMuxItemIt->second.defaultValue =
532 variableToRegister[variableOp].getResult();
535 if (!std::get_if<std::shared_ptr<CaseMuxItem>>(
536 &caseMuxItemIt->second.assignmentInState[currentState])) {
539 CaseMuxItem innerCaseMuxItem;
540 innerCaseMuxItem.wire = caseMuxItemIt->second.wire;
541 innerCaseMuxItem.select = nextStateWireRead;
542 caseMuxItemIt->second.assignmentInState[currentState] = {
543 std::make_shared<CaseMuxItem>(innerCaseMuxItem)};
549 auto &innerCaseMuxItem = std::get<std::shared_ptr<CaseMuxItem>>(
550 caseMuxItemIt->second.assignmentInState[currentState]);
551 innerCaseMuxItem->assignmentInState[targetState] = {targetValue};
557 llvm::SmallVector<CaseMuxItem, 4> nextStateCaseAssignments;
558 nextStateCaseAssignments.push_back(
559 CaseMuxItem{nextStateWire, stateReg, nextStateFromState});
560 for (
auto &[_, caseMuxItem] : variableCaseMuxItems)
561 nextStateCaseAssignments.push_back(caseMuxItem);
562 nextStateCaseAssignments.append(outputCaseAssignments.begin(),
563 outputCaseAssignments.end());
566 auto alwaysCombOp = b.create<sv::AlwaysCombOp>(loc);
567 OpBuilder::InsertionGuard g(b);
568 b.setInsertionPointToStart(alwaysCombOp.getBodyBlock());
569 buildStateCaseMux(nextStateCaseAssignments);
573 for (
auto &[variableOp, variableReg] : variableToRegister)
574 variableOp.getResult().replaceAllUsesWith(variableReg);
577 llvm::SmallVector<Value> outputPortAssignments;
578 for (
auto outputAssignment : outputCaseAssignments)
579 outputPortAssignments.push_back(
584 auto *oldOutputOp = hwModuleOp.getBodyBlock()->getTerminator();
585 b.create<hw::OutputOp>(loc, outputPortAssignments);
586 oldOutputOp->erase();
595 MachineOpConverter::convertTransitions(
596 StateOp currentState, ArrayRef<TransitionOp> transitions) {
598 if (transitions.empty()) {
601 nextState = encoding->encode(currentState);
604 auto transition = cast<fsm::TransitionOp>(transitions.front());
605 nextState = encoding->encode(transition.getNextStateOp());
608 if (transition.hasAction()) {
611 auto actionMoveOpsRes =
612 moveOps(&transition.getAction().front(),
613 [](Operation *op) { return isa<fsm::UpdateOp>(op); });
614 if (failed(actionMoveOpsRes))
618 DenseMap<fsm::VariableOp, Value> variableUpdates;
619 for (
auto updateOp : transition.getAction().getOps<fsm::UpdateOp>()) {
620 VariableOp variableOp = updateOp.getVariableOp();
621 variableUpdates[variableOp] = updateOp.getValue();
624 stateToVariableUpdates[currentState][transition.getNextStateOp()] =
629 if (transition.hasGuard()) {
632 auto guardOpRes = moveOps(&transition.getGuard().front());
633 if (failed(guardOpRes))
636 auto guardOp = cast<ReturnOp>(*guardOpRes);
637 assert(guardOp &&
"guard should be defined");
638 auto guard = guardOp.getOperand();
639 auto otherNextState =
640 convertTransitions(currentState, transitions.drop_front());
641 if (failed(otherNextState))
644 transition.getLoc(), guard, nextState, *otherNextState,
false);
645 nextState = nextStateMux;
649 assert(nextState &&
"next state should be defined");
653 FailureOr<MachineOpConverter::StateConversionResult>
654 MachineOpConverter::convertState(
StateOp state) {
655 MachineOpConverter::StateConversionResult res;
659 if (!state.getOutput().empty()) {
660 auto outputOpRes = moveOps(&state.getOutput().front());
661 if (failed(outputOpRes))
664 OutputOp outputOp = cast<fsm::OutputOp>(*outputOpRes);
665 res.outputs = outputOp.getOperands();
668 auto transitions = llvm::SmallVector<TransitionOp>(
672 auto nextStateRes = convertTransitions(state, transitions);
673 if (failed(nextStateRes))
675 res.nextState = *nextStateRes;
679 struct FSMToSVPass :
public circt::impl::ConvertFSMToSVBase<FSMToSVPass> {
680 void runOnOperation()
override;
683 void FSMToSVPass::runOnOperation() {
684 auto module = getOperation();
685 auto loc = module.getLoc();
686 auto b = OpBuilder(module);
689 auto machineOps = llvm::to_vector(module.getOps<
MachineOp>());
690 if (machineOps.empty()) {
691 markAllAnalysesPreserved();
698 b.setInsertionPointToStart(module.getBody());
701 typeScope.getBodyRegion().push_back(
new Block());
703 auto file = b.
create<emit::FileOp>(loc,
"fsm_enum_typedefs.sv", [&] {
704 b.
create<emit::RefOp>(loc,
707 auto fragment = b.create<emit::FragmentOp>(loc,
"FSM_ENUM_TYPEDEFS", [&] {
708 b.create<sv::VerbatimOp>(loc,
"`include \"" + file.getFileName() +
"\"");
714 for (
auto machineOp : machineOps) {
715 MachineOpConverter converter(b, typeScope, machineOp, headerName);
717 if (failed(converter.dispatch())) {
724 llvm::SmallVector<HWInstanceOp> instances;
725 module.walk([&](HWInstanceOp instance) { instances.push_back(instance); });
726 for (
auto instance : instances) {
730 "FSM machine should have been converted to a hw.module");
732 b.setInsertionPoint(instance);
733 llvm::SmallVector<Value, 4> operands;
734 llvm::transform(instance.getOperands(), std::back_inserter(operands),
735 [&](
auto operand) { return operand; });
736 auto hwInstance = b.create<hw::InstanceOp>(
737 instance.getLoc(), fsmHWModule, b.getStringAttr(instance.getName()),
739 instance.replaceAllUsesWith(hwInstance);
743 assert(!typeScope.getBodyBlock()->empty() &&
"missing type decls");
749 return std::make_unique<FSMToSVPass>();
assert(baseType &&"element must be base type")
static ClkRstIdxs getMachinePortInfo(SmallVectorImpl< hw::PortInfo > &ports, MachineOp machine, OpBuilder &b)
static void cloneConstantsIntoRegion(Region &region, OpBuilder &builder)
def create(data_type, value)
def create(str sym_name, Type type, str verilog_name=None)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
StringRef getFragmentsAttrName()
Return the name of the fragments array attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createConvertFSMToSVPass()