10 #include "../PassDetail.h"
18 #include "mlir/Transforms/RegionUtils.h"
19 #include "llvm/ADT/TypeSwitch.h"
25 using namespace circt;
38 machine.getHWPortInfo(ports);
39 ClkRstIdxs specialPorts;
43 clock.name = b.getStringAttr(
"clk");
46 clock.argNum = machine.getNumArguments();
47 ports.push_back(clock);
48 specialPorts.clockIdx = clock.argNum;
52 reset.name = b.getStringAttr(
"rst");
54 reset.type = b.getI1Type();
55 reset.argNum = machine.getNumArguments() + 1;
56 ports.push_back(reset);
57 specialPorts.resetIdx = reset.argNum;
65 llvm::SetVector<Value> captures;
66 getUsedValuesDefinedAbove(region, region, captures);
68 OpBuilder::InsertionGuard guard(
builder);
69 builder.setInsertionPointToStart(®ion.front());
72 for (
auto &capture : captures) {
73 Operation *op = capture.getDefiningOp();
74 if (!op || !op->hasTrait<OpTrait::ConstantLike>())
77 Operation *cloned =
builder.clone(*op);
78 for (
auto [orig, replacement] :
79 llvm::zip(op->getResults(), cloned->getResults()))
80 replaceAllUsesInRegionWith(orig, replacement, region);
101 Type getStateType() {
return stateType; }
104 std::unique_ptr<sv::CasePattern> getCasePattern(
StateOp state);
113 void setEncoding(
StateOp state, Value v,
bool wire =
false);
137 : typeScope(typeScope), b(b), machine(machine), hwModule(hwModule) {
138 Location loc = machine.getLoc();
139 llvm::SmallVector<Attribute> stateNames;
141 for (
auto state : machine.getBody().getOps<
StateOp>())
142 stateNames.push_back(b.getStringAttr(state.getName()));
148 OpBuilder::InsertionGuard guard(b);
149 b.setInsertionPointToStart(&typeScope.getBodyRegion().front());
151 loc, b.getStringAttr(hwModule.getName() +
"_state_t"),
156 {FlatSymbolRefAttr::get(typedeclEnumType)}),
160 b.setInsertionPointToStart(&hwModule.getBody().front());
161 for (
auto state : machine.getBody().getOps<
StateOp>()) {
163 loc, b.getStringAttr(state.getName()), stateType);
164 auto enumConstantOp = b.create<hw::EnumConstantOp>(
165 loc, fieldAttr.getType().getValue(), fieldAttr);
166 setEncoding(state, enumConstantOp,
172 Value StateEncoding::encode(
StateOp state) {
173 auto it = stateToValue.find(state);
174 assert(it != stateToValue.end() &&
"state not found");
178 StateOp StateEncoding::decode(Value value) {
179 auto it = valueToState.find(value);
180 assert(it != valueToState.end() &&
"encoded state not found");
185 std::unique_ptr<sv::CasePattern> StateEncoding::getCasePattern(
StateOp state) {
188 cast<hw::EnumConstantOp>(valueToSrcValue[encode(state)].getDefiningOp())
190 return std::make_unique<sv::CaseEnumPattern>(fieldAttr);
193 void StateEncoding::setEncoding(
StateOp state, Value v,
bool wire) {
194 assert(stateToValue.find(state) == stateToValue.end() &&
195 "state already encoded");
199 auto loc = machine.getLoc();
200 auto stateType = getStateType();
201 auto stateEncodingWire = b.create<
sv::RegOp>(
202 loc, stateType, b.getStringAttr(
"to_" + state.getName()),
208 stateToValue[state] = encodedValue;
209 valueToState[encodedValue] = state;
210 valueToSrcValue[encodedValue] = v;
213 class MachineOpConverter {
216 MachineOp machineOp, FlatSymbolRefAttr headerName)
217 : machineOp(machineOp), typeScope(typeScope), b(
builder),
218 headerName(headerName) {}
238 LogicalResult dispatch();
241 struct StateConversionResult {
245 llvm::SmallVector<Value>
outputs;
248 using StateConversionResults = DenseMap<StateOp, StateConversionResult>;
259 ArrayRef<TransitionOp> transitions);
266 moveOps(Block *block,
267 llvm::function_ref<
bool(Operation *)> exclude =
nullptr);
270 using StateCaseMapping =
272 std::variant<Value, std::shared_ptr<CaseMuxItem>>>;
283 StateCaseMapping assignmentInState;
287 std::optional<Value> defaultValue = {};
293 void buildStateCaseMux(llvm::MutableArrayRef<CaseMuxItem> assignments);
296 std::unique_ptr<StateEncoding> encoding;
299 llvm::SmallVector<StateOp> orderedStates;
310 llvm::DenseMap< VariableOp, Value>>>
311 stateToVariableUpdates;
327 FlatSymbolRefAttr headerName;
331 MachineOpConverter::moveOps(Block *block,
332 llvm::function_ref<
bool(Operation *)> exclude) {
333 for (
auto &op : llvm::make_early_inc_range(*block)) {
334 if (!isa<comb::CombDialect, hw::HWDialect, fsm::FSMDialect>(
336 return op.emitOpError()
337 <<
"is unsupported (op from the "
338 << op.getDialect()->getNamespace() <<
" dialect).";
340 if (exclude && exclude(&op))
343 if (op.hasTrait<OpTrait::IsTerminator>())
346 op.moveBefore(hwModuleOp.getBodyBlock(), b.getInsertionPoint());
351 void MachineOpConverter::buildStateCaseMux(
352 llvm::MutableArrayRef<CaseMuxItem> assignments) {
356 Value select = assignments.front().select;
359 [&](
const CaseMuxItem &item) {
return item.select == select; }) &&
360 "All assignments must use the same select signal.");
364 for (
auto &assignment : assignments) {
365 if (assignment.defaultValue)
366 b.create<sv::BPAssignOp>(assignment.wire.getLoc(), assignment.wire,
367 *assignment.defaultValue);
371 caseMux = b.create<sv::CaseOp>(
372 machineOp.getLoc(), CaseStmtType::CaseStmt,
374 machineOp.getNumStates() + 1, [&](
size_t caseIdx) {
376 if (caseIdx == machineOp.getNumStates())
377 return std::unique_ptr<sv::CasePattern>(
378 new sv::CaseDefaultPattern(b.getContext()));
379 StateOp state = orderedStates[caseIdx];
380 return encoding->getCasePattern(state);
384 for (
auto assignment : assignments) {
385 OpBuilder::InsertionGuard g(b);
386 for (
auto [caseInfo, stateOp] :
387 llvm::zip(caseMux.getCases(), orderedStates)) {
388 auto assignmentInState = assignment.assignmentInState.find(stateOp);
389 if (assignmentInState == assignment.assignmentInState.end())
391 b.setInsertionPointToEnd(caseInfo.block);
392 if (
auto v = std::get_if<Value>(&assignmentInState->second); v) {
393 b.create<sv::BPAssignOp>(machineOp.getLoc(), assignment.wire, *v);
396 llvm::SmallVector<CaseMuxItem, 4> nestedAssignments;
397 nestedAssignments.push_back(
398 *
std::get<std::shared_ptr<CaseMuxItem>>(assignmentInState->second));
399 buildStateCaseMux(nestedAssignments);
405 LogicalResult MachineOpConverter::dispatch() {
406 b.setInsertionPoint(machineOp);
407 auto loc = machineOp.getLoc();
408 if (machineOp.getNumStates() < 2)
409 return machineOp.emitOpError() <<
"expected at least 2 states.";
416 SmallVector<hw::PortInfo, 16> ports;
418 hwModuleOp = b.create<
hw::HWModuleOp>(loc, machineOp.getSymNameAttr(), ports);
420 b.getArrayAttr({headerName}));
421 b.setInsertionPointToStart(hwModuleOp.getBodyBlock());
425 for (
auto args : llvm::zip(machineOp.getArguments(),
426 hwModuleOp.getBodyBlock()->getArguments())) {
427 auto machineArg = std::get<0>(args);
428 auto hwModuleArg = std::get<1>(args);
429 machineArg.replaceAllUsesWith(hwModuleArg);
432 auto clock = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.clockIdx);
433 auto reset = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.resetIdx);
437 std::make_unique<StateEncoding>(b, typeScope, machineOp, hwModuleOp);
438 auto stateType = encoding->getStateType();
442 auto nextStateWireRead = b.create<
sv::ReadInOutOp>(loc, nextStateWire);
444 loc, nextStateWireRead, clock, reset,
445 encoding->encode(machineOp.getInitialStateOp()),
448 llvm::DenseMap<VariableOp, sv::RegOp> variableNextStateWires;
449 for (
auto variableOp : machineOp.front().getOps<fsm::VariableOp>()) {
450 auto initValueAttr = variableOp.getInitValueAttr().dyn_cast<IntegerAttr>();
452 return variableOp.emitOpError() <<
"expected an integer attribute "
453 "for the initial value.";
454 Type varType = variableOp.getType();
455 auto varLoc = variableOp.getLoc();
457 varLoc, varType, b.getStringAttr(variableOp.getName() +
"_next"));
458 auto varResetVal = b.create<
hw::ConstantOp>(varLoc, initValueAttr);
461 varResetVal, b.getStringAttr(variableOp.getName() +
"_reg"));
462 variableToRegister[variableOp] = variableReg;
463 variableNextStateWires[variableOp] = varNextState;
471 if (failed(moveOps(&machineOp.front(), [](Operation *op) {
472 return isa<fsm::StateOp, fsm::VariableOp>(op);
477 StateCaseMapping nextStateFromState;
478 StateConversionResults stateConvResults;
479 for (
auto state : machineOp.getBody().getOps<
StateOp>()) {
480 auto stateConvRes = convertState(state);
481 if (failed(stateConvRes))
484 stateConvResults[state] = *stateConvRes;
485 orderedStates.push_back(state);
486 nextStateFromState[state] = {stateConvRes->nextState};
490 llvm::SmallVector<CaseMuxItem, 4> outputCaseAssignments;
491 auto hwPortList = hwModuleOp.getPortList();
492 size_t portIndex = 0;
493 for (
auto &port : hwPortList) {
494 if (!port.isOutput())
496 auto outputPortType = port.type;
497 CaseMuxItem outputAssignment;
498 outputAssignment.wire = b.create<
sv::RegOp>(
499 machineOp.getLoc(), outputPortType,
500 b.getStringAttr(
"output_" + std::to_string(portIndex)));
501 outputAssignment.select = stateReg;
502 for (
auto &state : orderedStates)
503 outputAssignment.assignmentInState[state] = {
504 stateConvResults[state].outputs[portIndex]};
506 outputCaseAssignments.push_back(outputAssignment);
511 llvm::DenseMap<VariableOp, CaseMuxItem> variableCaseMuxItems;
512 for (
auto &[currentState, it] : stateToVariableUpdates) {
513 for (
auto &[targetState, it2] : it) {
514 for (
auto &[variableOp, targetValue] : it2) {
515 auto caseMuxItemIt = variableCaseMuxItems.find(variableOp);
516 if (caseMuxItemIt == variableCaseMuxItems.end()) {
520 variableCaseMuxItems[variableOp];
521 caseMuxItemIt = variableCaseMuxItems.find(variableOp);
523 assert(variableNextStateWires.count(variableOp));
524 caseMuxItemIt->second.wire = variableNextStateWires[variableOp];
525 caseMuxItemIt->second.select = stateReg;
526 caseMuxItemIt->second.defaultValue =
527 variableToRegister[variableOp].getResult();
530 if (!std::get_if<std::shared_ptr<CaseMuxItem>>(
531 &caseMuxItemIt->second.assignmentInState[currentState])) {
534 CaseMuxItem innerCaseMuxItem;
535 innerCaseMuxItem.wire = caseMuxItemIt->second.wire;
536 innerCaseMuxItem.select = nextStateWireRead;
537 caseMuxItemIt->second.assignmentInState[currentState] = {
538 std::make_shared<CaseMuxItem>(innerCaseMuxItem)};
544 auto &innerCaseMuxItem = std::get<std::shared_ptr<CaseMuxItem>>(
545 caseMuxItemIt->second.assignmentInState[currentState]);
546 innerCaseMuxItem->assignmentInState[targetState] = {targetValue};
552 llvm::SmallVector<CaseMuxItem, 4> nextStateCaseAssignments;
553 nextStateCaseAssignments.push_back(
554 CaseMuxItem{nextStateWire, stateReg, nextStateFromState});
555 for (
auto &[_, caseMuxItem] : variableCaseMuxItems)
556 nextStateCaseAssignments.push_back(caseMuxItem);
557 nextStateCaseAssignments.append(outputCaseAssignments.begin(),
558 outputCaseAssignments.end());
561 auto alwaysCombOp = b.create<sv::AlwaysCombOp>(loc);
562 OpBuilder::InsertionGuard g(b);
563 b.setInsertionPointToStart(alwaysCombOp.getBodyBlock());
564 buildStateCaseMux(nextStateCaseAssignments);
568 for (
auto &[variableOp, variableReg] : variableToRegister)
569 variableOp.getResult().replaceAllUsesWith(variableReg);
572 llvm::SmallVector<Value> outputPortAssignments;
573 for (
auto outputAssignment : outputCaseAssignments)
574 outputPortAssignments.push_back(
579 auto *oldOutputOp = hwModuleOp.getBodyBlock()->getTerminator();
580 b.create<hw::OutputOp>(loc, outputPortAssignments);
581 oldOutputOp->erase();
590 MachineOpConverter::convertTransitions(
591 StateOp currentState, ArrayRef<TransitionOp> transitions) {
593 if (transitions.empty()) {
596 nextState = encoding->encode(currentState);
599 auto transition = cast<fsm::TransitionOp>(transitions.front());
600 nextState = encoding->encode(transition.getNextStateOp());
603 if (transition.hasAction()) {
606 auto actionMoveOpsRes =
607 moveOps(&transition.getAction().front(),
608 [](Operation *op) { return isa<fsm::UpdateOp>(op); });
609 if (failed(actionMoveOpsRes))
613 DenseMap<fsm::VariableOp, Value> variableUpdates;
614 for (
auto updateOp : transition.getAction().getOps<fsm::UpdateOp>()) {
615 VariableOp variableOp = updateOp.getVariableOp();
616 variableUpdates[variableOp] = updateOp.getValue();
619 stateToVariableUpdates[currentState][transition.getNextStateOp()] =
624 if (transition.hasGuard()) {
627 auto guardOpRes = moveOps(&transition.getGuard().front());
628 if (failed(guardOpRes))
631 auto guardOp = cast<ReturnOp>(*guardOpRes);
632 assert(guardOp &&
"guard should be defined");
633 auto guard = guardOp.getOperand();
634 auto otherNextState =
635 convertTransitions(currentState, transitions.drop_front());
636 if (failed(otherNextState))
639 transition.getLoc(), guard, nextState, *otherNextState,
false);
640 nextState = nextStateMux;
644 assert(nextState &&
"next state should be defined");
649 MachineOpConverter::convertState(
StateOp state) {
650 MachineOpConverter::StateConversionResult res;
654 if (!state.getOutput().empty()) {
655 auto outputOpRes = moveOps(&state.getOutput().front());
656 if (failed(outputOpRes))
659 OutputOp outputOp = cast<fsm::OutputOp>(*outputOpRes);
660 res.outputs = outputOp.getOperands();
663 auto transitions = llvm::SmallVector<TransitionOp>(
667 auto nextStateRes = convertTransitions(state, transitions);
668 if (failed(nextStateRes))
670 res.nextState = *nextStateRes;
674 struct FSMToSVPass :
public ConvertFSMToSVBase<FSMToSVPass> {
675 void runOnOperation()
override;
678 void FSMToSVPass::runOnOperation() {
679 auto module = getOperation();
680 auto loc = module.getLoc();
681 auto b = OpBuilder(module);
684 auto machineOps = llvm::to_vector(module.getOps<
MachineOp>());
685 if (machineOps.empty()) {
686 markAllAnalysesPreserved();
693 b.setInsertionPointToStart(module.getBody());
696 typeScope.getBodyRegion().push_back(
new Block());
698 auto file = b.
create<emit::FileOp>(loc,
"fsm_enum_typedefs.sv", [&] {
699 b.
create<emit::RefOp>(loc,
702 auto fragment = b.create<emit::FragmentOp>(loc,
"FSM_ENUM_TYPEDEFS", [&] {
703 b.create<sv::VerbatimOp>(loc,
"`include \"" + file.getFileName() +
"\"");
709 for (
auto machineOp : machineOps) {
710 MachineOpConverter converter(b, typeScope, machineOp, headerName);
712 if (failed(converter.dispatch())) {
719 llvm::SmallVector<HWInstanceOp> instances;
720 module.walk([&](HWInstanceOp instance) { instances.push_back(instance); });
721 for (
auto instance : instances) {
725 "FSM machine should have been converted to a hw.module");
727 b.setInsertionPoint(instance);
728 llvm::SmallVector<Value, 4> operands;
729 llvm::transform(instance.getOperands(), std::back_inserter(operands),
730 [&](
auto operand) { return operand; });
731 auto hwInstance = b.create<hw::InstanceOp>(
732 instance.getLoc(), fsmHWModule, b.getStringAttr(instance.getName()),
734 instance.replaceAllUsesWith(hwInstance);
738 assert(!typeScope.getBodyBlock()->empty() &&
"missing type decls");
744 return std::make_unique<FSMToSVPass>();
assert(baseType &&"element must be base type")
static ClkRstIdxs getMachinePortInfo(SmallVectorImpl< hw::PortInfo > &ports, MachineOp machine, OpBuilder &b)
static void cloneConstantsIntoRegion(Region ®ion, OpBuilder &builder)
llvm::SmallVector< StringAttr > outputs
def create(data_type, value)
def create(str sym_name, Type type, str verilog_name=None)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
StringRef getFragmentsAttrName()
Return the name of the fragments array attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createConvertFSMToSVPass()