10 #include "../PassDetail.h"
17 #include "mlir/Transforms/RegionUtils.h"
18 #include "llvm/ADT/TypeSwitch.h"
24 using namespace circt;
35 MachineOp machine, OpBuilder &b) {
37 machine.getHWPortInfo(ports);
38 ClkRstIdxs specialPorts;
42 clock.name = b.getStringAttr(
"clk");
45 clock.argNum = machine.getNumArguments();
46 ports.push_back(clock);
47 specialPorts.clockIdx = clock.argNum;
51 reset.name = b.getStringAttr(
"rst");
53 reset.type = b.getI1Type();
54 reset.argNum = machine.getNumArguments() + 1;
55 ports.push_back(reset);
56 specialPorts.resetIdx = reset.argNum;
64 llvm::SetVector<Value> captures;
65 getUsedValuesDefinedAbove(region, region, captures);
67 OpBuilder::InsertionGuard guard(
builder);
68 builder.setInsertionPointToStart(®ion.front());
71 for (
auto &capture : captures) {
72 Operation *op = capture.getDefiningOp();
73 if (!op || !op->hasTrait<OpTrait::ConstantLike>())
76 Operation *cloned =
builder.clone(*op);
77 for (
auto [orig, replacement] :
78 llvm::zip(op->getResults(), cloned->getResults()))
79 replaceAllUsesInRegionWith(orig, replacement, region);
91 StateEncoding(OpBuilder &b, hw::TypeScopeOp typeScope, MachineOp machine,
92 hw::HWModuleOp hwModule);
95 Value encode(StateOp state);
97 StateOp decode(Value value);
100 Type getStateType() {
return stateType; }
103 std::unique_ptr<sv::CasePattern> getCasePattern(StateOp state);
112 void setEncoding(StateOp state, Value v,
bool wire =
false);
124 hw::TypeScopeOp typeScope;
131 hw::HWModuleOp hwModule;
134 StateEncoding::StateEncoding(OpBuilder &b, hw::TypeScopeOp typeScope,
135 MachineOp machine, hw::HWModuleOp hwModule)
136 : typeScope(typeScope), b(b), machine(machine), hwModule(hwModule) {
137 Location loc = machine.getLoc();
138 llvm::SmallVector<Attribute> stateNames;
140 for (
auto state : machine.getBody().getOps<StateOp>())
141 stateNames.push_back(b.getStringAttr(state.getName()));
147 OpBuilder::InsertionGuard guard(b);
148 b.setInsertionPointToStart(&typeScope.getBodyRegion().front());
149 auto typedeclEnumType = b.create<hw::TypedeclOp>(
150 loc, b.getStringAttr(hwModule.getName() +
"_state_t"),
155 {FlatSymbolRefAttr::get(typedeclEnumType)}),
159 b.setInsertionPointToStart(&hwModule.getBody().front());
160 for (
auto state : machine.getBody().getOps<StateOp>()) {
162 loc, b.getStringAttr(state.getName()), stateType);
163 auto enumConstantOp = b.create<hw::EnumConstantOp>(
164 loc, fieldAttr.getType().getValue(), fieldAttr);
165 setEncoding(state, enumConstantOp,
171 Value StateEncoding::encode(StateOp state) {
172 auto it = stateToValue.find(state);
173 assert(it != stateToValue.end() &&
"state not found");
177 StateOp StateEncoding::decode(Value value) {
178 auto it = valueToState.find(value);
179 assert(it != valueToState.end() &&
"encoded state not found");
184 std::unique_ptr<sv::CasePattern> StateEncoding::getCasePattern(StateOp state) {
187 cast<hw::EnumConstantOp>(valueToSrcValue[encode(state)].getDefiningOp())
189 return std::make_unique<sv::CaseEnumPattern>(fieldAttr);
192 void StateEncoding::setEncoding(StateOp state, Value v,
bool wire) {
193 assert(stateToValue.find(state) == stateToValue.end() &&
194 "state already encoded");
198 auto loc = machine.getLoc();
199 auto stateType = getStateType();
200 auto stateEncodingWire = b.create<sv::RegOp>(
201 loc, stateType, b.getStringAttr(
"to_" + state.getName()),
203 b.create<sv::AssignOp>(loc, stateEncodingWire, v);
204 encodedValue = b.create<sv::ReadInOutOp>(loc, stateEncodingWire);
207 stateToValue[state] = encodedValue;
208 valueToState[encodedValue] = state;
209 valueToSrcValue[encodedValue] = v;
212 class MachineOpConverter {
214 MachineOpConverter(OpBuilder &
builder, hw::TypeScopeOp typeScope,
216 : machineOp(machineOp), typeScope(typeScope), b(
builder) {}
236 LogicalResult dispatch();
239 struct StateConversionResult {
243 llvm::SmallVector<Value>
outputs;
246 using StateConversionResults = DenseMap<StateOp, StateConversionResult>;
257 ArrayRef<TransitionOp> transitions);
264 moveOps(Block *block,
265 llvm::function_ref<
bool(Operation *)> exclude =
nullptr);
268 using StateCaseMapping =
270 std::variant<Value, std::shared_ptr<CaseMuxItem>>>;
281 StateCaseMapping assignmentInState;
285 std::optional<Value> defaultValue = {};
291 void buildStateCaseMux(llvm::MutableArrayRef<CaseMuxItem> assignments);
294 std::unique_ptr<StateEncoding> encoding;
297 llvm::SmallVector<StateOp> orderedStates;
308 llvm::DenseMap< VariableOp, Value>>>
309 stateToVariableUpdates;
315 hw::HWModuleOp hwModuleOp;
318 seq::CompRegOp stateReg;
321 hw::TypeScopeOp typeScope;
327 MachineOpConverter::moveOps(Block *block,
328 llvm::function_ref<
bool(Operation *)> exclude) {
329 for (
auto &op : llvm::make_early_inc_range(*block)) {
330 if (!isa<comb::CombDialect, hw::HWDialect, fsm::FSMDialect>(
332 return op.emitOpError()
333 <<
"is unsupported (op from the "
334 << op.getDialect()->getNamespace() <<
" dialect).";
336 if (exclude && exclude(&op))
339 if (op.hasTrait<OpTrait::IsTerminator>())
342 op.moveBefore(hwModuleOp.getBodyBlock(), b.getInsertionPoint());
347 void MachineOpConverter::buildStateCaseMux(
348 llvm::MutableArrayRef<CaseMuxItem> assignments) {
352 Value select = assignments.front().select;
355 [&](
const CaseMuxItem &item) {
return item.select == select; }) &&
356 "All assignments must use the same select signal.");
360 for (
auto &assignment : assignments) {
361 if (assignment.defaultValue)
362 b.create<sv::BPAssignOp>(assignment.wire.getLoc(), assignment.wire,
363 *assignment.defaultValue);
367 caseMux = b.create<sv::CaseOp>(
368 machineOp.getLoc(), CaseStmtType::CaseStmt,
370 machineOp.getNumStates() + 1, [&](
size_t caseIdx) {
372 if (caseIdx == machineOp.getNumStates())
373 return std::unique_ptr<sv::CasePattern>(
374 new sv::CaseDefaultPattern(b.getContext()));
375 StateOp state = orderedStates[caseIdx];
376 return encoding->getCasePattern(state);
380 for (
auto assignment : assignments) {
381 OpBuilder::InsertionGuard g(b);
382 for (
auto [caseInfo, stateOp] :
383 llvm::zip(caseMux.getCases(), orderedStates)) {
384 auto assignmentInState = assignment.assignmentInState.find(stateOp);
385 if (assignmentInState == assignment.assignmentInState.end())
387 b.setInsertionPointToEnd(caseInfo.block);
388 if (
auto v = std::get_if<Value>(&assignmentInState->second); v) {
389 b.create<sv::BPAssignOp>(machineOp.getLoc(), assignment.wire, *v);
392 llvm::SmallVector<CaseMuxItem, 4> nestedAssignments;
393 nestedAssignments.push_back(
394 *
std::get<std::shared_ptr<CaseMuxItem>>(assignmentInState->second));
395 buildStateCaseMux(nestedAssignments);
401 LogicalResult MachineOpConverter::dispatch() {
402 b.setInsertionPoint(machineOp);
403 auto loc = machineOp.getLoc();
404 if (machineOp.getNumStates() < 2)
405 return machineOp.emitOpError() <<
"expected at least 2 states.";
412 SmallVector<hw::PortInfo, 16> ports;
414 hwModuleOp = b.create<hw::HWModuleOp>(loc, machineOp.getSymNameAttr(), ports);
415 b.setInsertionPointToStart(hwModuleOp.getBodyBlock());
419 for (
auto args : llvm::zip(machineOp.getArguments(),
420 hwModuleOp.getBodyBlock()->getArguments())) {
421 auto machineArg = std::get<0>(args);
422 auto hwModuleArg = std::get<1>(args);
423 machineArg.replaceAllUsesWith(hwModuleArg);
426 auto clock = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.clockIdx);
427 auto reset = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.resetIdx);
431 std::make_unique<StateEncoding>(b, typeScope, machineOp, hwModuleOp);
432 auto stateType = encoding->getStateType();
435 b.create<sv::RegOp>(loc, stateType, b.getStringAttr(
"state_next"));
436 auto nextStateWireRead = b.create<sv::ReadInOutOp>(loc, nextStateWire);
437 stateReg = b.create<seq::CompRegOp>(
438 loc, stateType, nextStateWireRead, clock,
"state_reg", reset,
439 encoding->encode(machineOp.getInitialStateOp()),
nullptr);
441 llvm::DenseMap<VariableOp, sv::RegOp> variableNextStateWires;
442 for (
auto variableOp : machineOp.front().getOps<fsm::VariableOp>()) {
443 auto initValueAttr = variableOp.getInitValueAttr().dyn_cast<IntegerAttr>();
445 return variableOp.emitOpError() <<
"expected an integer attribute "
446 "for the initial value.";
447 Type varType = variableOp.getType();
448 auto varLoc = variableOp.getLoc();
449 auto varNextState = b.create<sv::RegOp>(
450 varLoc, varType, b.getStringAttr(variableOp.getName() +
"_next"));
451 auto varResetVal = b.create<hw::ConstantOp>(varLoc, initValueAttr);
452 auto variableReg = b.create<seq::CompRegOp>(
453 varLoc, varType, b.create<sv::ReadInOutOp>(varLoc, varNextState), clock,
454 b.getStringAttr(variableOp.getName() +
"_reg"), reset, varResetVal,
456 variableToRegister[variableOp] = variableReg;
457 variableNextStateWires[variableOp] = varNextState;
465 if (failed(moveOps(&machineOp.front(), [](Operation *op) {
466 return isa<fsm::StateOp, fsm::VariableOp>(op);
471 StateCaseMapping nextStateFromState;
472 StateConversionResults stateConvResults;
473 for (
auto state : machineOp.getBody().getOps<StateOp>()) {
474 auto stateConvRes = convertState(state);
475 if (failed(stateConvRes))
478 stateConvResults[state] = *stateConvRes;
479 orderedStates.push_back(state);
480 nextStateFromState[state] = {stateConvRes->nextState};
484 llvm::SmallVector<CaseMuxItem, 4> outputCaseAssignments;
485 auto hwPortList = hwModuleOp.getPortList();
486 for (
size_t portIndex = 0; portIndex < machineOp.getNumResults();
488 auto outputPort = hwPortList.atOutput(portIndex);
489 auto outputPortType = outputPort.type;
490 CaseMuxItem outputAssignment;
491 outputAssignment.wire = b.create<sv::RegOp>(
492 machineOp.getLoc(), outputPortType,
493 b.getStringAttr(
"output_" + std::to_string(portIndex)));
494 outputAssignment.select = stateReg;
495 for (
auto &state : orderedStates)
496 outputAssignment.assignmentInState[state] = {
497 stateConvResults[state].outputs[portIndex]};
499 outputCaseAssignments.push_back(outputAssignment);
503 llvm::DenseMap<VariableOp, CaseMuxItem> variableCaseMuxItems;
504 for (
auto &[currentState, it] : stateToVariableUpdates) {
505 for (
auto &[targetState, it2] : it) {
506 for (
auto &[variableOp, targetValue] : it2) {
507 auto caseMuxItemIt = variableCaseMuxItems.find(variableOp);
508 if (caseMuxItemIt == variableCaseMuxItems.end()) {
512 variableCaseMuxItems[variableOp];
513 caseMuxItemIt = variableCaseMuxItems.find(variableOp);
515 assert(variableNextStateWires.count(variableOp));
516 caseMuxItemIt->second.wire = variableNextStateWires[variableOp];
517 caseMuxItemIt->second.select = stateReg;
518 caseMuxItemIt->second.defaultValue =
519 variableToRegister[variableOp].getResult();
522 if (!std::get_if<std::shared_ptr<CaseMuxItem>>(
523 &caseMuxItemIt->second.assignmentInState[currentState])) {
526 CaseMuxItem innerCaseMuxItem;
527 innerCaseMuxItem.wire = caseMuxItemIt->second.wire;
528 innerCaseMuxItem.select = nextStateWireRead;
529 caseMuxItemIt->second.assignmentInState[currentState] = {
530 std::make_shared<CaseMuxItem>(innerCaseMuxItem)};
536 auto &innerCaseMuxItem = std::get<std::shared_ptr<CaseMuxItem>>(
537 caseMuxItemIt->second.assignmentInState[currentState]);
538 innerCaseMuxItem->assignmentInState[targetState] = {targetValue};
544 llvm::SmallVector<CaseMuxItem, 4> nextStateCaseAssignments;
545 nextStateCaseAssignments.push_back(
546 CaseMuxItem{nextStateWire, stateReg, nextStateFromState});
547 for (
auto &[_, caseMuxItem] : variableCaseMuxItems)
548 nextStateCaseAssignments.push_back(caseMuxItem);
549 nextStateCaseAssignments.append(outputCaseAssignments.begin(),
550 outputCaseAssignments.end());
553 auto alwaysCombOp = b.create<sv::AlwaysCombOp>(loc);
554 OpBuilder::InsertionGuard g(b);
555 b.setInsertionPointToStart(alwaysCombOp.getBodyBlock());
556 buildStateCaseMux(nextStateCaseAssignments);
560 for (
auto &[variableOp, variableReg] : variableToRegister)
561 variableOp.getResult().replaceAllUsesWith(variableReg);
564 llvm::SmallVector<Value> outputPortAssignments;
565 for (
auto outputAssignment : outputCaseAssignments)
566 outputPortAssignments.push_back(
567 b.create<sv::ReadInOutOp>(machineOp.getLoc(), outputAssignment.wire));
571 auto *oldOutputOp = hwModuleOp.getBodyBlock()->getTerminator();
572 b.create<hw::OutputOp>(loc, outputPortAssignments);
573 oldOutputOp->erase();
582 MachineOpConverter::convertTransitions(
583 StateOp currentState, ArrayRef<TransitionOp> transitions) {
585 if (transitions.empty()) {
588 nextState = encoding->encode(currentState);
591 auto transition = cast<fsm::TransitionOp>(transitions.front());
592 nextState = encoding->encode(transition.getNextStateOp());
595 if (transition.hasAction()) {
598 auto actionMoveOpsRes =
599 moveOps(&transition.getAction().front(),
600 [](Operation *op) { return isa<fsm::UpdateOp>(op); });
601 if (failed(actionMoveOpsRes))
605 DenseMap<fsm::VariableOp, Value> variableUpdates;
606 for (
auto updateOp : transition.getAction().getOps<fsm::UpdateOp>()) {
607 VariableOp variableOp = updateOp.getVariableOp();
608 variableUpdates[variableOp] = updateOp.getValue();
611 stateToVariableUpdates[currentState][transition.getNextStateOp()] =
616 if (transition.hasGuard()) {
619 auto guardOpRes = moveOps(&transition.getGuard().front());
620 if (failed(guardOpRes))
623 auto guardOp = cast<ReturnOp>(*guardOpRes);
624 assert(guardOp &&
"guard should be defined");
625 auto guard = guardOp.getOperand();
626 auto otherNextState =
627 convertTransitions(currentState, transitions.drop_front());
628 if (failed(otherNextState))
630 comb::MuxOp nextStateMux = b.create<comb::MuxOp>(
631 transition.getLoc(), guard, nextState, *otherNextState,
false);
632 nextState = nextStateMux;
636 assert(nextState &&
"next state should be defined");
641 MachineOpConverter::convertState(StateOp state) {
642 MachineOpConverter::StateConversionResult res;
646 if (!state.getOutput().empty()) {
647 auto outputOpRes = moveOps(&state.getOutput().front());
648 if (failed(outputOpRes))
651 OutputOp outputOp = cast<fsm::OutputOp>(*outputOpRes);
652 res.outputs = outputOp.getOperands();
655 auto transitions = llvm::SmallVector<TransitionOp>(
656 state.getTransitions().getOps<TransitionOp>());
659 auto nextStateRes = convertTransitions(state, transitions);
660 if (failed(nextStateRes))
662 res.nextState = *nextStateRes;
666 struct FSMToSVPass :
public ConvertFSMToSVBase<FSMToSVPass> {
667 void runOnOperation()
override;
670 void FSMToSVPass::runOnOperation() {
671 auto module = getOperation();
672 auto b = OpBuilder(module);
673 SmallVector<Operation *, 16> opToErase;
678 StringAttr typeScopeFilename = b.getStringAttr(
"fsm_enum_typedefs.sv");
679 b.setInsertionPointToStart(module.getBody());
680 auto typeScope = b.create<hw::TypeScopeOp>(
681 module.getLoc(), b.getStringAttr(
"fsm_enum_typedecls"));
682 typeScope.getBodyRegion().push_back(
new Block());
686 b.getBoolAttr(
false),
687 b.getBoolAttr(
false)));
690 for (
auto machine : llvm::make_early_inc_range(module.getOps<MachineOp>())) {
691 MachineOpConverter converter(b, typeScope, machine);
693 if (failed(converter.dispatch())) {
700 llvm::SmallVector<HWInstanceOp> instances;
701 module.walk([&](HWInstanceOp instance) { instances.push_back(instance); });
702 for (
auto instance : instances) {
704 module.lookupSymbol<hw::HWModuleOp>(instance.getMachine());
706 "FSM machine should have been converted to a hw.module");
708 b.setInsertionPoint(instance);
709 llvm::SmallVector<Value, 4> operands;
710 llvm::transform(instance.getOperands(), std::back_inserter(operands),
711 [&](
auto operand) { return operand; });
712 auto hwInstance = b.create<hw::InstanceOp>(
713 instance.getLoc(), fsmHWModule, b.getStringAttr(instance.getName()),
715 instance.replaceAllUsesWith(hwInstance);
719 if (typeScope.getBodyBlock()->empty()) {
725 b.setInsertionPointToStart(module.getBody());
726 b.create<sv::VerbatimOp>(
727 module.getLoc(),
"`include \"" + typeScopeFilename.getValue() +
"\"");
734 return std::make_unique<FSMToSVPass>();
assert(baseType &&"element must be base type")
static ClkRstIdxs getMachinePortInfo(SmallVectorImpl< hw::PortInfo > &ports, MachineOp machine, OpBuilder &b)
static void cloneConstantsIntoRegion(Region ®ion, OpBuilder &builder)
llvm::SmallVector< StringAttr > outputs
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
std::unique_ptr< mlir::Pass > createConvertFSMToSVPass()