17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/MLIRContext.h"
19#include "mlir/IR/Value.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Transforms/RegionUtils.h"
22#include "llvm/ADT/DenseMap.h"
23#include "llvm/ADT/SmallVector.h"
26#define GEN_PASS_DEF_CONVERTFSMTOCORE
27#include "circt/Conversion/Passes.h.inc"
43 llvm::SetVector<Value> captures;
44 getUsedValuesDefinedAbove(region, region, captures);
46 OpBuilder::InsertionGuard guard(builder);
47 builder.setInsertionPointToStart(&region.front());
50 for (
auto &capture : captures) {
51 Operation *op = capture.getDefiningOp();
52 if (!op || !op->hasTrait<OpTrait::ConstantLike>())
55 Operation *cloned = builder.clone(*op);
56 for (
auto [orig, replacement] :
57 llvm::zip(op->getResults(), cloned->getResults()))
58 replaceAllUsesInRegionWith(orig, replacement, region);
65 machine.getHWPortInfo(ports);
66 ClkRstIdxs specialPorts;
70 clock.
name = b.getStringAttr(
"clk");
71 clock.
dir = hw::ModulePort::Direction::Input;
72 clock.
type = seq::ClockType::get(b.getContext());
73 clock.
argNum = machine.getNumArguments();
74 ports.push_back(clock);
75 specialPorts.clockIdx = clock.
argNum;
79 reset.
name = b.getStringAttr(
"rst");
80 reset.
dir = hw::ModulePort::Direction::Input;
81 reset.
type = b.getI1Type();
82 reset.
argNum = machine.getNumArguments() + 1;
83 ports.push_back(reset);
84 specialPorts.resetIdx = reset.
argNum;
103 Type getStateType() {
return stateType; }
108 void setEncoding(
StateOp state, Value v);
123StateEncoding::StateEncoding(OpBuilder &b,
MachineOp machine,
125 : b(b), machine(machine), hwModule(hwModule) {
126 Location loc = machine.getLoc();
128 OpBuilder::InsertionGuard guard(b);
129 b.setInsertionPointToStart(&hwModule.getBodyRegion().front());
132 if (machine->getAttr(
"stateType")) {
134 stateType = cast<TypeAttr>(machine->getAttr(
"stateType")).getValue();
136 int numStates = std::distance(machine.getBody().getOps<
StateOp>().begin(),
137 machine.getBody().getOps<
StateOp>().end());
139 IntegerType::get(machine.getContext(), llvm::Log2_64_Ceil(numStates));
143 b.setInsertionPointToStart(&hwModule.getBody().front());
144 for (
auto state : machine.getBody().getOps<
StateOp>()) {
146 setEncoding(state, constantOp);
151Value StateEncoding::encode(
StateOp state) {
152 auto it = stateToValue.find(state);
153 assert(it != stateToValue.end() &&
"state not found");
157void StateEncoding::setEncoding(
StateOp state, Value v) {
158 assert(stateToValue.find(state) == stateToValue.end() &&
159 "state already encoded");
160 stateToValue[state] = v;
161 valueToState[v] = state;
166class MachineOpConverter {
168 MachineOpConverter(OpBuilder &builder,
MachineOp machineOp)
169 : machineOp(machineOp), b(builder),
187 LogicalResult dispatch();
192 LogicalResult convertState(
StateOp state);
198 FailureOr<Value> convertTransitions(
StateOp currentState,
199 ArrayRef<TransitionOp> transitions);
205 FailureOr<Operation *>
206 moveOps(Block *block,
207 llvm::function_ref<
bool(Operation *)> exclude =
nullptr);
209 DenseMap<Value, std::string> backedgeMap;
212 std::unique_ptr<StateEncoding> encoding;
215 llvm::MapVector<VariableOp, seq::CompRegOp> variableToRegister;
219 llvm::MapVector<VariableOp, mlir::Value> variableToMuxChainOut;
222 llvm::SmallVector<mlir::Value> outputMuxChainOuts;
235 mlir::Value stateMuxChainOut;
241LogicalResult MachineOpConverter::dispatch() {
242 b.setInsertionPoint(machineOp);
243 auto loc = machineOp.getLoc();
250 SmallVector<hw::PortInfo, 16> ports;
253 hw::HWModuleOp::create(b, loc, machineOp.getSymNameAttr(), ports);
254 b.setInsertionPointToStart(hwModuleOp.getBodyBlock());
258 for (
auto [machineArg, hwModuleArg] :
259 llvm::zip(machineOp.getArguments(),
261 machineArg.replaceAllUsesWith(hwModuleArg);
264 auto clock = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.clockIdx);
265 auto reset = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.resetIdx);
269 encoding = std::make_unique<StateEncoding>(b, machineOp, hwModuleOp);
270 auto stateType = encoding->getStateType();
272 Backedge nextStateWire = bb.get(stateType);
274 auto initialStateOp = machineOp.getInitialStateOp();
276 b, loc, nextStateWire, clock, reset,
277 encoding->encode(initialStateOp),
"state_reg",
279 seq::createConstantInitialValue(
280 b, encoding->encode(initialStateOp).getDefiningOp()));
281 stateMuxChainOut = stateReg;
283 llvm::DenseMap<VariableOp, Backedge> variableNextStateWires;
284 for (
auto variableOp : machineOp.front().getOps<
fsm::VariableOp>()) {
285 auto initValueAttr = cast<IntegerAttr>(variableOp.getInitValueAttr());
286 Type varType = variableOp.getType();
287 auto varLoc = variableOp.getLoc();
288 auto nextVariableStateWire = bb.get(varType);
291 b, varLoc, nextVariableStateWire, clock, reset, varResetVal,
292 b.getStringAttr(variableOp.getName()),
293 seq::createConstantInitialValue(b, varResetVal));
294 variableToRegister[variableOp] = variableReg;
295 variableNextStateWires[variableOp] = nextVariableStateWire;
296 variableToMuxChainOut[variableOp] = variableReg;
304 if (failed(moveOps(&machineOp.front(), [](Operation *op) {
305 return isa<fsm::StateOp, fsm::VariableOp>(op);
310 auto hwPortList = hwModuleOp.getPortList();
311 llvm::SmallVector<Backedge> outputBackedges;
312 for (
auto &port : hwPortList)
314 outputMuxChainOuts.push_back(Value());
317 for (
auto state : machineOp.getBody().getOps<
StateOp>()) {
318 auto stateConvRes = convertState(state);
319 if (failed(stateConvRes))
325 nextStateWire.
setValue(stateMuxChainOut);
326 for (
auto [variable, muxChainOut] : variableToMuxChainOut) {
327 variableNextStateWires[variable].setValue(muxChainOut);
331 for (
auto [variableOp, variableReg] : variableToRegister)
332 variableOp.getResult().replaceAllUsesWith(variableReg);
335 llvm::SmallVector<Value> outputValues;
336 for (
auto backedge : outputMuxChainOuts) {
337 outputValues.push_back(backedge);
339 auto *oldOutputOp = hwModuleOp.getBodyBlock()->getTerminator();
340 b.setInsertionPointToEnd(oldOutputOp->getBlock());
341 oldOutputOp->erase();
342 hw::OutputOp::create(b, loc, outputValues);
348MachineOpConverter::convertTransitions(
349 StateOp currentState, ArrayRef<TransitionOp> transitions) {
351 llvm::MapVector<fsm::VariableOp, Value> variableUpdates;
353 comb::ICmpOp::create(b, machineOp.getLoc(), comb::ICmpPredicate::eq,
354 stateReg, encoding->encode(currentState));
355 if (transitions.empty()) {
358 nextState = encoding->encode(currentState);
361 auto transition = cast<fsm::TransitionOp>(transitions.front());
362 nextState = encoding->encode(transition.getNextStateOp());
363 mlir::Value varUpdateCondition;
365 if (transition.hasAction()) {
368 auto actionMoveOpsRes =
369 moveOps(&transition.getAction().front(),
370 [](Operation *op) { return isa<fsm::UpdateOp>(op); });
371 if (failed(actionMoveOpsRes))
375 for (
auto updateOp : transition.getAction().getOps<
fsm::UpdateOp>()) {
376 VariableOp variableOp = updateOp.getVariableOp();
377 variableUpdates[variableOp] = updateOp.getValue();
382 if (transition.hasGuard()) {
385 auto guardOpRes = moveOps(&transition.getGuard().front());
386 if (failed(guardOpRes))
389 auto guardOp = cast<ReturnOp>(*guardOpRes);
390 assert(guardOp &&
"guard should be defined");
391 auto guard = guardOp.getOperand();
392 auto otherNextState =
393 convertTransitions(currentState, transitions.drop_front());
394 if (failed(otherNextState))
397 b, transition.getLoc(), guard, nextState, *otherNextState,
false);
398 nextState = nextStateMux;
400 comb::AndOp::create(b, machineOp.getLoc(), guard, stateCmp);
402 varUpdateCondition = stateCmp;
404 for (
auto variableUpdate : variableUpdates) {
405 auto muxChainOut = variableToMuxChainOut[variableUpdate.first];
406 auto newMuxChainOut =
407 comb::MuxOp::create(b, machineOp.getLoc(), varUpdateCondition,
408 variableUpdate.second, muxChainOut,
false);
409 variableToMuxChainOut[variableUpdate.first] = newMuxChainOut;
413 stateMuxChainOut = comb::MuxOp::create(b, machineOp.getLoc(), stateCmp,
414 nextState, stateMuxChainOut);
415 assert(nextState &&
"next state should be defined");
419FailureOr<Operation *>
420MachineOpConverter::moveOps(Block *block,
421 llvm::function_ref<
bool(Operation *)> exclude) {
422 for (
auto &op :
llvm::make_early_inc_range(*block)) {
423 if (!isa<comb::CombDialect, hw::HWDialect, fsm::FSMDialect>(
427 return op.emitOpError()
428 <<
"is unsupported (op from the "
429 << op.getDialect()->getNamespace() <<
" dialect).";
431 if (exclude && exclude(&op))
434 if (op.hasTrait<OpTrait::IsTerminator>())
437 op.moveBefore(hwModuleOp.getBodyBlock(), b.getInsertionPoint());
442LogicalResult MachineOpConverter::convertState(
StateOp state) {
445 if (!state.getOutput().empty()) {
446 auto outputOpRes = moveOps(&state.getOutput().front());
447 if (failed(outputOpRes))
451 OutputOp outputOp = cast<fsm::OutputOp>(*outputOpRes);
453 comb::ICmpOp::create(b, machineOp.getLoc(), comb::ICmpPredicate::eq,
454 stateReg, encoding->encode(state));
456 for (
auto [i, operand] :
llvm::enumerate(outputOp.getOperands())) {
457 auto muxChainOut = outputMuxChainOuts[i];
461 outputMuxChainOuts[i] = operand;
464 auto muxOp = comb::MuxOp::create(b, machineOp.getLoc(), stateCmp, operand,
466 outputMuxChainOuts[i] = muxOp;
470 auto transitions = llvm::SmallVector<TransitionOp>(
474 auto nextStateRes = convertTransitions(state, transitions);
475 if (failed(nextStateRes))
480struct FSMToCorePass :
public circt::impl::ConvertFSMToCoreBase<FSMToCorePass> {
481 void runOnOperation()
override;
484void FSMToCorePass::runOnOperation() {
485 auto module = getOperation();
486 auto b = OpBuilder(module);
487 SmallVector<Operation *, 16> opToErase;
489 b.setInsertionPointToStart(module.getBody());
491 for (
auto machine :
llvm::make_early_inc_range(module.getOps<
MachineOp>())) {
494 if (machine->getAttr(
"stateType")) {
495 auto stateType = dyn_cast<TypeAttr>(machine->getAttr(
"stateType"));
497 machine->emitError(
"stateType attribute does not name a type");
501 if (!isa<IntegerType>(stateType.getValue())) {
502 machine->emitError(
"stateType attribute must name an integer type");
507 for (
auto variableOp : machine.front().getOps<
fsm::VariableOp>()) {
508 if (!isa<IntegerType>(variableOp.getType())) {
509 variableOp.emitOpError(
510 "only integer variables are currently supported");
516 MachineOpConverter converter(b, machine);
518 if (failed(converter.dispatch())) {
525 llvm::SmallVector<HWInstanceOp> instances;
526 module.walk([&](HWInstanceOp instance) { instances.push_back(instance); });
527 for (
auto instance : instances) {
529 module.lookupSymbol<hw::HWModuleOp>(instance.getMachine());
531 "FSM machine should have been converted to a hw.module");
533 b.setInsertionPoint(instance);
534 llvm::SmallVector<Value, 4> operands;
535 llvm::transform(instance.getOperands(), std::back_inserter(operands),
536 [&](
auto operand) { return operand; });
537 auto hwInstance = hw::InstanceOp::create(
538 b, instance.getLoc(), fsmHWModule, b.getStringAttr(instance.getName()),
540 instance.replaceAllUsesWith(hwInstance);
548 return std::make_unique<FSMToCorePass>();
assert(baseType &&"element must be base type")
static ClkRstIdxs getMachinePortInfo(SmallVectorImpl< hw::PortInfo > &ports, MachineOp machine, OpBuilder &b)
static void cloneConstantsIntoRegion(Region &region, OpBuilder &builder)
static Location getLoc(DefSlot slot)
static Block * getBodyBlock(FModuleLike mod)
Instantiate one of these and use it to build typed backedges.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createConvertFSMToCorePass()
This holds the name, type, direction of a module's ports.
size_t argNum
This is the argument index or the result index depending on the direction.